diff --git a/crates/common/src/mq.rs b/crates/common/src/mq.rs
index a886cf6..ca5318f 100644
--- a/crates/common/src/mq.rs
+++ b/crates/common/src/mq.rs
@@ -215,15 +215,15 @@ where
     info!(queue = %cfg.queue, prefetch = cfg.prefetch, "consumer started");
 
     tokio::pin!(consumer);
-    let mut shutdown = Box::pin(shutdown);
+    let mut shutdown: core::pin::Pin<Box<dyn core::future::Future<Output = ()> + Send>> = Box::pin(shutdown);
 
-    loop {
+    'consume: loop {
         tokio::select! {
             _ = &mut shutdown => {
                 info!("shutdown requested; canceling consumer");
-                // Best-effort cancel; ignore errors during shutdown.
-                let _ = channel.basic_cancel("orchestrator", BasicCancelOptions { nowait: false }).await;
-                break;
+                // Best-effort cancel; do not wait for confirmation.
+                //let _ = tokio::time::timeout(Duration::from_millis(300), channel.basic_cancel("orchestrator", BasicCancelOptions { nowait: true })).await;
+                break 'consume;
             }
             maybe_delivery = consumer.next() => {
                 match maybe_delivery {
@@ -267,5 +267,21 @@ where
         }
     }
 
+    // Explicitly drop the consumer stream before closing channel/connection to stop background tasks
+    drop(consumer);
+
+    // Close channel and connection to stop heartbeats and background tasks
+    match tokio::time::timeout(Duration::from_secs(2), channel.close(200, "shutdown")).await {
+        Ok(Ok(_)) => {},
+        Ok(Err(e)) => warn!(error = %e, "failed to close AMQP channel"),
+        Err(_) => warn!("timeout while closing AMQP channel"),
+    }
+    match tokio::time::timeout(Duration::from_secs(2), conn.close(200, "shutdown")).await {
+        Ok(Ok(_)) => {},
+        Ok(Err(e)) => warn!(error = %e, "failed to close AMQP connection"),
+        Err(_) => warn!("timeout while closing AMQP connection"),
+    }
+
+    info!("consume_jobs completed");
     Ok(())
 }
diff --git a/crates/orchestrator/src/main.rs b/crates/orchestrator/src/main.rs
index 4876875..2540a0c 100644
--- a/crates/orchestrator/src/main.rs
+++ b/crates/orchestrator/src/main.rs
@@ -13,6 +13,7 @@ use config::OrchestratorConfig;
 use hypervisor::{RouterHypervisor, VmSpec, JobContext};
 use scheduler::{Scheduler, SchedItem};
 use std::sync::Arc;
+use tokio::sync::Notify;
 use crate::persist::{Persist, JobState};
 
 #[derive(Parser, Debug)]
@@ -113,9 +114,11 @@ async fn main() -> Result<()> {
         persist.clone(),
         Duration::from_secs(opts.vm_placeholder_run_secs),
     );
+    let sched_shutdown = Arc::new(Notify::new());
     let sched_tx = sched.sender();
+    let sched_shutdown_task = sched_shutdown.clone();
     let scheduler_task = tokio::spawn(async move {
-        if let Err(e) = sched.run().await {
+        if let Err(e) = sched.run_with_shutdown(sched_shutdown_task).await {
             warn!(error = %e, "scheduler stopped with error");
         }
     });
@@ -182,12 +185,14 @@ async fn main() -> Result<()> {
     // Signal consumer to stop accepting new jobs
     let _ = shutdown_tx.send(());
 
+    // Ask scheduler to shut down cooperatively (interrupt placeholders)
+    sched_shutdown.notify_waiters();
+
     // Drop sender to let scheduler drain and exit
     drop(sched_tx);
 
-    // Wait for consumer and scheduler to finish
-    let _ = consumer_task.await;
-    let _ = scheduler_task.await;
+    // Wait for consumer and scheduler to finish concurrently
+    let (_c_res, _s_res) = tokio::join!(consumer_task, scheduler_task);
 
     Ok(())
 }
diff --git a/crates/orchestrator/src/scheduler.rs b/crates/orchestrator/src/scheduler.rs
index 05caa53..ae4e0b2 100644
--- a/crates/orchestrator/src/scheduler.rs
+++ b/crates/orchestrator/src/scheduler.rs
@@ -2,7 +2,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
 
 use dashmap::DashMap;
 use miette::Result;
-use tokio::sync::{mpsc, Semaphore};
+use tokio::sync::{mpsc, Semaphore, Notify};
 use tracing::{error, info, warn};
 
 use crate::hypervisor::{Hypervisor, VmSpec, JobContext, BackendTag};
@@ -45,77 +45,97 @@ impl Scheduler {
     pub fn sender(&self) -> mpsc::Sender<SchedItem> { self.tx.clone() }
 
-    pub async fn run(self) -> Result<()> {
+    pub async fn run_with_shutdown(self, shutdown: Arc<Notify>) -> Result<()> {
         let Scheduler { hv, mut rx, global_sem, label_sems, persist, placeholder_runtime, .. } = self;
         let mut handles = Vec::new();
 
-        while let Some(item) = rx.recv().await {
-            let hv = hv.clone();
-            let global = global_sem.clone();
-            let label_sems = label_sems.clone();
-            let persist = persist.clone();
-            let handle = tokio::spawn(async move {
-                // Acquire global and label permits (owned permits so they live inside the task)
-                let _g = match global.acquire_owned().await {
-                    Ok(p) => p,
-                    Err(_) => return,
-                };
-                let label_key = item.spec.label.clone();
-                let sem_arc = if let Some(entry) = label_sems.get(&label_key) {
-                    entry.clone()
-                } else {
-                    let s = Arc::new(Semaphore::new(1));
-                    label_sems.insert(label_key.clone(), s.clone());
-                    s
-                };
-                let _l = match sem_arc.acquire_owned().await {
-                    Ok(p) => p,
-                    Err(_) => return,
-                };
+        let mut shutting_down = false;
+        'scheduler: loop {
+            tokio::select! {
+                // Only react once to shutdown and then focus on draining the channel
+                _ = shutdown.notified(), if !shutting_down => {
+                    shutting_down = true;
+                    info!("scheduler: shutdown requested; will drain queue and interrupt placeholders");
+                    break 'scheduler;
+                }
+                maybe_item = rx.recv() => {
+                    match maybe_item {
+                        Some(item) => {
+                            let hv = hv.clone();
+                            let global = global_sem.clone();
+                            let label_sems = label_sems.clone();
+                            let persist = persist.clone();
+                            let shutdown = shutdown.clone();
+                            let handle = tokio::spawn(async move {
+                                // Acquire global and label permits (owned permits so they live inside the task)
+                                let _g = match global.acquire_owned().await { Ok(p) => p, Err(_) => return };
+                                let label_key = item.spec.label.clone();
+                                let sem_arc = if let Some(entry) = label_sems.get(&label_key) { entry.clone() } else {
+                                    let s = Arc::new(Semaphore::new(1));
+                                    label_sems.insert(label_key.clone(), s.clone());
+                                    s
+                                };
+                                let _l = match sem_arc.acquire_owned().await { Ok(p) => p, Err(_) => return };
 
-                // Provision and run
-                match hv.prepare(&item.spec, &item.ctx).await {
-                    Ok(h) => {
-                        // persist prepared VM state
-                        let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str());
-                        let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str());
-                        let backend = match h.backend { BackendTag::Noop => Some("noop"), #[cfg(all(target_os = "linux", feature = "libvirt"))] BackendTag::Libvirt => Some("libvirt"), #[cfg(target_os = "illumos")] BackendTag::Zones => Some("zones") };
-                        if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await {
-                            warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed");
+                                // Provision and run
+                                match hv.prepare(&item.spec, &item.ctx).await {
+                                    Ok(h) => {
+                                        // persist prepared VM state
+                                        let overlay = h.overlay_path.as_ref().and_then(|p| p.to_str());
+                                        let seed = h.seed_iso_path.as_ref().and_then(|p| p.to_str());
+                                        let backend = match h.backend { BackendTag::Noop => Some("noop"), #[cfg(all(target_os = "linux", feature = "libvirt"))] BackendTag::Libvirt => Some("libvirt"), #[cfg(target_os = "illumos")] BackendTag::Zones => Some("zones") };
+                                        if let Err(e) = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Prepared).await {
+                                            warn!(error = %e, request_id = %item.ctx.request_id, domain = %h.id, "persist prepare failed");
+                                        }
+                                        if let Err(e) = hv.start(&h).await {
+                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM");
+                                            let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
+                                            return;
+                                        }
+                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
+                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
+                                        info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (workload execution placeholder)");
+                                        // Placeholder job runtime (configurable), but end early on shutdown
+                                        tokio::select! {
+                                            _ = tokio::time::sleep(placeholder_runtime) => {},
+                                            _ = shutdown.notified() => {
+                                                info!(request_id = %item.ctx.request_id, label = %label_key, "shutdown: ending placeholder early");
+                                            }
+                                        }
+                                        // Stop and destroy
+                                        if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
+                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
+                                        }
+                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await;
+                                        if let Err(e) = hv.destroy(h.clone()).await {
+                                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM");
+                                        }
+                                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
+                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await;
+                                    }
+                                    Err(e) => {
+                                        error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM");
+                                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
+                                        return;
+                                    }
+                                }
+                            });
+                            handles.push(handle);
                         }
-                        if let Err(e) = hv.start(&h).await {
-                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to start VM");
-                            let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
-                            return;
-                        }
-                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Running).await;
-                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Running).await;
-                        info!(request_id = %item.ctx.request_id, label = %label_key, "vm started (workload execution placeholder)");
-                        // Placeholder job runtime (configurable)
-                        tokio::time::sleep(placeholder_runtime).await;
-                        // Stop and destroy
-                        if let Err(e) = hv.stop(&h, Duration::from_secs(10)).await {
-                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to stop VM");
-                        }
-                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Stopped).await;
-                        if let Err(e) = hv.destroy(h.clone()).await {
-                            error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to destroy VM");
-                        }
-                        let _ = persist.record_vm_event(item.ctx.request_id, &h.id, overlay, seed, backend, VmPersistState::Destroyed).await;
-                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Succeeded).await;
-                    }
-                    Err(e) => {
-                        error!(error = %e, request_id = %item.ctx.request_id, label = %label_key, "failed to prepare VM");
-                        let _ = persist.record_job_state(item.ctx.request_id, &item.ctx.repo_url, &item.ctx.commit_sha, Some(&item.spec.label), JobState::Failed).await;
-                        return;
+                        None => break,
                     }
                 }
-            });
-            handles.push(handle);
+            }
         }
 
         // Wait for all in-flight tasks to finish
         for h in handles { let _ = h.await; }
 
+        info!("scheduler: completed");
         Ok(())
     }
+
+    pub async fn run(self) -> Result<()> {
+        // Backward-compat: run until channel closed (no external shutdown)
+        self.run_with_shutdown(Arc::new(Notify::new())).await
+    }
 }