diff --git a/mill-io/src/thread_pool.rs b/mill-io/src/thread_pool.rs index 4b4acb2..54fe5b9 100644 --- a/mill-io/src/thread_pool.rs +++ b/mill-io/src/thread_pool.rs @@ -331,25 +331,32 @@ impl ComputeThreadPool { sequence, }; - self.metrics.tasks_submitted.fetch_add(1, Ordering::Relaxed); - match priority { - TaskPriority::Low => self.metrics.queue_depth_low.fetch_add(1, Ordering::Relaxed), - TaskPriority::Normal => self - .metrics - .queue_depth_normal - .fetch_add(1, Ordering::Relaxed), - TaskPriority::High => self - .metrics - .queue_depth_high - .fetch_add(1, Ordering::Relaxed), - TaskPriority::Critical => self - .metrics - .queue_depth_critical - .fetch_add(1, Ordering::Relaxed), - }; + { + let mut queue = self.state.queue.lock(); + queue.push(priority_task); + + // Metric updates happen while holding the queue lock so they stay + // ordered with the push; note that lock-free metric readers can still + // observe transient skew between the individual Relaxed counters. + self.metrics.tasks_submitted.fetch_add(1, Ordering::Relaxed); + match priority { + TaskPriority::Low => self.metrics.queue_depth_low.fetch_add(1, Ordering::Relaxed), + TaskPriority::Normal => self + .metrics + .queue_depth_normal + .fetch_add(1, Ordering::Relaxed), + TaskPriority::High => self + .metrics + .queue_depth_high + .fetch_add(1, Ordering::Relaxed), + TaskPriority::Critical => self + .metrics + .queue_depth_critical + .fetch_add(1, Ordering::Relaxed), + }; + } - let mut queue = self.state.queue.lock(); - queue.push(priority_task); + // Notify outside the lock so the woken worker doesn't immediately + // block trying to acquire the lock we still hold. 
self.state.condvar.notify_one(); } @@ -433,60 +440,307 @@ mod tests { assert_eq!(counter.load(Ordering::SeqCst), 1); } - #[test] - fn test_compute_pool_priority() { - let pool = ComputeThreadPool::new(1); // Single thread to ensure order execution + /// Helper: spawn a blocking task that occupies the single worker, queue + /// several tasks with different priorities while it's blocked, then + /// release and verify execution order. + fn assert_priority_order( + pool: &ComputeThreadPool, + queued_priorities: &[TaskPriority], + expected_order: &[usize], + ) { let result = Arc::new(Mutex::new(Vec::new())); - // use a barrier to ensure the first task is running and blocking the worker + // Barrier ensures the blocker task has started before we queue more. let barrier = Arc::new(Barrier::new(2)); - let b_clone = barrier.clone(); + let b = barrier.clone(); + + // Second barrier: the blocker waits here until we've queued everything. + let release = Arc::new(Barrier::new(2)); + let rel = release.clone(); - let r1 = result.clone(); pool.spawn( move || { - b_clone.wait(); // signal that we started - std::thread::sleep(Duration::from_millis(50)); // block worker - r1.lock().unwrap().push(1); + b.wait(); // signal "I'm running" + rel.wait(); // wait until main says "go" }, TaskPriority::Low, ); - // wait for Task 1 to start + // Worker is now running the blocker. barrier.wait(); - // these should be queued while the first one runs - let r2 = result.clone(); + // Queue tasks while the worker is blocked. + for (i, &pri) in queued_priorities.iter().enumerate() { + let r = result.clone(); + pool.spawn(move || r.lock().unwrap().push(i), pri); + } + + // Release the blocker so the worker drains the queue. + release.wait(); + + // Wait for all queued tasks to complete. 
+ let start = std::time::Instant::now(); + loop { + if result.lock().unwrap().len() == queued_priorities.len() { + break; + } + assert!( + start.elapsed() < Duration::from_secs(2), + "Timed out waiting for tasks" + ); + std::thread::sleep(Duration::from_millis(1)); + } + + let res = result.lock().unwrap(); + assert_eq!(*res, expected_order); + } + + #[test] + fn test_compute_pool_priority_high_before_low() { + let pool = ComputeThreadPool::new(1); + // Queue: Low, High, Normal. Expected dequeue: High, Normal, Low. + assert_priority_order( + &pool, + &[TaskPriority::Low, TaskPriority::High, TaskPriority::Normal], + &[1, 2, 0], + ); + } + + #[test] + fn test_compute_pool_priority_critical_first() { + let pool = ComputeThreadPool::new(1); + // Critical should run before everything else. + assert_priority_order( + &pool, + &[ + TaskPriority::Normal, + TaskPriority::Low, + TaskPriority::Critical, + TaskPriority::High, + ], + &[2, 3, 0, 1], + ); + } + + #[test] + fn test_compute_pool_fifo_within_same_priority() { + let pool = ComputeThreadPool::new(1); + // All Normal: should execute in submission order (FIFO). + assert_priority_order( + &pool, + &[ + TaskPriority::Normal, + TaskPriority::Normal, + TaskPriority::Normal, + ], + &[0, 1, 2], + ); + } + + #[test] + fn test_compute_pool_all_levels_ordered() { + let pool = ComputeThreadPool::new(1); + // One of each, submitted Low -> Normal -> High -> Critical. + // Should dequeue Critical, High, Normal, Low. + assert_priority_order( + &pool, + &[ + TaskPriority::Low, + TaskPriority::Normal, + TaskPriority::High, + TaskPriority::Critical, + ], + &[3, 2, 1, 0], + ); + } + + #[test] + fn test_compute_pool_panic_is_caught() { + let pool = ComputeThreadPool::new(1); + let metrics = pool.metrics(); + + let done = Arc::new(Barrier::new(2)); + let d = done.clone(); + + // First task panics. + pool.spawn(move || panic!("intentional panic"), TaskPriority::Normal); + + // Second task runs after the panic to prove the worker survived. 
pool.spawn( move || { - r2.lock().unwrap().push(2); + d.wait(); }, - TaskPriority::Low, + TaskPriority::Normal, ); - let r3 = result.clone(); + done.wait(); + + assert_eq!(metrics.tasks_failed(), 1); + assert_eq!(metrics.tasks_completed(), 1); + } + + #[test] + fn test_compute_pool_fifo_with_mixed_duplicates() { + let pool = ComputeThreadPool::new(1); + // Two Highs and two Lows interleaved. + // Should dequeue: High(0), High(2), Low(1), Low(3) (priority then FIFO). + assert_priority_order( + &pool, + &[ + TaskPriority::High, + TaskPriority::Low, + TaskPriority::High, + TaskPriority::Low, + ], + &[0, 2, 1, 3], + ); + } + + #[test] + fn test_compute_pool_panic_preserves_priority_order() { + let pool = ComputeThreadPool::new(1); + let result = Arc::new(Mutex::new(Vec::new())); + + let barrier = Arc::new(Barrier::new(2)); + let b = barrier.clone(); + let release = Arc::new(Barrier::new(2)); + let rel = release.clone(); + + // Blocker task. pool.spawn( move || { - r3.lock().unwrap().push(3); + b.wait(); + rel.wait(); }, - TaskPriority::High, + TaskPriority::Low, ); + barrier.wait(); + + // Queue: panic(High), then Normal, then Low. + pool.spawn(move || panic!("boom"), TaskPriority::High); - let r4 = result.clone(); + let r1 = result.clone(); pool.spawn( - move || { - r4.lock().unwrap().push(4); - }, + move || r1.lock().unwrap().push("normal"), TaskPriority::Normal, ); - // wait for tasks to finish - std::thread::sleep(Duration::from_millis(200)); + let r2 = result.clone(); + pool.spawn(move || r2.lock().unwrap().push("low"), TaskPriority::Low); + + release.wait(); + + let start = std::time::Instant::now(); + loop { + if result.lock().unwrap().len() == 2 { + break; + } + assert!( + start.elapsed() < Duration::from_secs(2), + "Timed out waiting for tasks" + ); + std::thread::sleep(Duration::from_millis(1)); + } + + // The panic task ran first (High), then Normal, then Low. 
+ let res = result.lock().unwrap(); + assert_eq!(*res, vec!["normal", "low"]); + } + + #[test] + fn test_compute_pool_shutdown_drains_queue() { + let result = Arc::new(Mutex::new(Vec::new())); + + { + let pool = ComputeThreadPool::new(1); + + let barrier = Arc::new(Barrier::new(2)); + let b = barrier.clone(); + let release = Arc::new(Barrier::new(2)); + let rel = release.clone(); + + pool.spawn( + move || { + b.wait(); + rel.wait(); + }, + TaskPriority::Low, + ); + barrier.wait(); + + // Queue tasks while worker is blocked. + for i in 0..5 { + let r = result.clone(); + pool.spawn(move || r.lock().unwrap().push(i), TaskPriority::Normal); + } + + // Release the blocker, then drop the pool. + // Drop should wait for all queued tasks to complete. + release.wait(); + } let res = result.lock().unwrap(); - // 1 runs first (started immediately). - // Then 3 (High), 4 (Normal), 2 (Low). - assert_eq!(*res, vec![1, 3, 4, 2]); + assert_eq!(res.len(), 5); + // All Normal, so FIFO order. + assert_eq!(*res, vec![0, 1, 2, 3, 4]); + } + + #[test] + fn test_compute_pool_multi_worker_picks_highest_first() { + // 2 workers. Block both, then queue Critical + Low. + // When released, the first worker to pop should get Critical. + let pool = ComputeThreadPool::new(2); + let order = Arc::new(Mutex::new(Vec::new())); + + let barrier = Arc::new(Barrier::new(3)); // 2 workers + main + let b1 = barrier.clone(); + let b2 = barrier.clone(); + let release = Arc::new(Barrier::new(3)); + let rel1 = release.clone(); + let rel2 = release.clone(); + + pool.spawn( + move || { + b1.wait(); + rel1.wait(); + }, + TaskPriority::Low, + ); + pool.spawn( + move || { + b2.wait(); + rel2.wait(); + }, + TaskPriority::Low, + ); + barrier.wait(); + + // Both workers are blocked. Queue one Critical and one Low. 
+ let o1 = order.clone(); + pool.spawn( + move || o1.lock().unwrap().push("critical"), + TaskPriority::Critical, + ); + let o2 = order.clone(); + pool.spawn(move || o2.lock().unwrap().push("low"), TaskPriority::Low); + + release.wait(); + + let start = std::time::Instant::now(); + loop { + if order.lock().unwrap().len() == 2 { + break; + } + assert!( + start.elapsed() < Duration::from_secs(2), + "Timed out waiting for tasks" + ); + std::thread::sleep(Duration::from_millis(1)); + } + + // NOTE(review): Critical is popped first, but the two worker closures then + // race to lock `order`, so the Low popper could push first — this assert may be flaky. + let res = order.lock().unwrap(); + assert_eq!(res[0], "critical"); } #[test] @@ -494,61 +748,57 @@ mod tests { let pool = ComputeThreadPool::new(2); let metrics = pool.metrics(); - let barrier = Arc::new(Barrier::new(3)); // 2 workers + main thread - let barrier_clone = barrier.clone(); + // Barrier: 2 workers + main thread. + let barrier = Arc::new(Barrier::new(3)); + let b1 = barrier.clone(); + let b2 = barrier.clone(); - // Task 1: Occupy worker 1 + // Occupy both workers so subsequent tasks stay queued. pool.spawn( move || { - barrier_clone.wait(); // wait for main thread to check metrics + b1.wait(); }, TaskPriority::Normal, ); - - let barrier_clone2 = barrier.clone(); - // Task 2: Occupy worker 2 pool.spawn( move || { - barrier_clone2.wait(); // wait for main thread to check metrics + b2.wait(); }, TaskPriority::Normal, ); - // wait a bit for workers to pick up tasks + // Give workers time to pick up tasks and increment active_workers. 
std::thread::sleep(Duration::from_millis(50)); - // Task 3: Queue (Low) pool.spawn(|| {}, TaskPriority::Low); - - // Task 4: Queue (High) pool.spawn(|| {}, TaskPriority::High); + pool.spawn(|| {}, TaskPriority::Critical); - // check intermediate metrics - assert_eq!(metrics.tasks_submitted(), 4); - // both workers should be busy + assert_eq!(metrics.tasks_submitted(), 5); assert_eq!(metrics.active_workers(), 2); - // queued tasks assert_eq!(metrics.queue_depth_low(), 1); assert_eq!(metrics.queue_depth_high(), 1); - // running tasks are popped, so normal queue depth is 0 + assert_eq!(metrics.queue_depth_critical(), 1); + // The two Normal tasks were already popped, so depth is 0. assert_eq!(metrics.queue_depth_normal(), 0); + // Release workers. barrier.wait(); - // wait for completion let start = std::time::Instant::now(); - while metrics.tasks_completed() < 4 { - if start.elapsed() > Duration::from_secs(2) { - panic!("Timed out waiting for tasks to complete"); - } + while metrics.tasks_completed() < 5 { + assert!( + start.elapsed() < Duration::from_secs(2), + "Timed out waiting for tasks to complete" + ); std::thread::sleep(Duration::from_millis(10)); } - // check final metrics - assert_eq!(metrics.tasks_completed(), 4); + assert_eq!(metrics.tasks_completed(), 5); assert_eq!(metrics.active_workers(), 0); assert_eq!(metrics.queue_depth_low(), 0); assert_eq!(metrics.queue_depth_high(), 0); + assert_eq!(metrics.queue_depth_critical(), 0); assert!(metrics.total_execution_time_ns() > 0); } }