#![warn(rust_2018_idioms)]
#![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))]

use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use tokio::runtime::Builder;

/// Number of tasks each hook-counting test spawns.
const TASKS: usize = 8;
/// Number of times a test yields back to the runtime to let spawned tasks run.
const ITERATIONS: usize = 64;

/// Assert that the spawn task hook always fires when set.
///
/// Two things are checked after spawning `TASKS` tasks: the hook was invoked
/// exactly `TASKS` times, and the set of task IDs it observed has exactly
/// `TASKS` entries (i.e. every invocation reported a distinct ID).
#[test]
fn spawn_task_hook_fires() {
    let count = Arc::new(AtomicUsize::new(0));
    let count2 = Arc::clone(&count);

    let ids = Arc::new(Mutex::new(HashSet::new()));
    let ids2 = Arc::clone(&ids);

    let runtime = Builder::new_current_thread()
        .on_task_spawn(move |data| {
            ids2.lock().unwrap().insert(data.id());
            count2.fetch_add(1, Ordering::SeqCst);
        })
        .build()
        .unwrap();

    // The runtime is never driven in this test, so the assertions below rely
    // on the hook having run as part of `spawn` itself.
    for _ in 0..TASKS {
        runtime.spawn(std::future::pending::<()>());
    }

    let count_realized = count.load(Ordering::SeqCst);
    assert_eq!(
        TASKS, count_realized,
        "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {count_realized}"
    );

    // The set deduplicates IDs, so its length is the number of distinct IDs.
    let count_ids_realized = ids.lock().unwrap().len();
    assert_eq!(
        TASKS, count_ids_realized,
        "Total number of unique task IDs observed by the spawn hook was incorrect, expected {TASKS}, got {count_ids_realized}"
    );
}

/// Assert that the terminate task hook always fires when set.
///
/// Spawns `TASKS` immediately-ready tasks, drives the runtime for
/// `ITERATIONS` yields, and then expects one hook invocation per task.
#[test]
fn terminate_task_hook_fires() {
    let count = Arc::new(AtomicUsize::new(0));
    let count2 = Arc::clone(&count);

    let runtime = Builder::new_current_thread()
        .on_task_terminate(move |_data| {
            count2.fetch_add(1, Ordering::SeqCst);
        })
        .build()
        .unwrap();

    for _ in 0..TASKS {
        runtime.spawn(std::future::ready(()));
    }

    runtime.block_on(async {
        // tick the runtime a bunch to close out tasks
        for _ in 0..ITERATIONS {
            tokio::task::yield_now().await;
        }
    });

    assert_eq!(TASKS, count.load(Ordering::SeqCst));
}

/// Test that the correct spawn location is provided to the task hooks on a
/// current thread runtime.
#[test] fn task_hook_spawn_location_current_thread() { let spawns = Arc::new(AtomicUsize::new(0)); let poll_starts = Arc::new(AtomicUsize::new(0)); let poll_ends = Arc::new(AtomicUsize::new(0)); let runtime = Builder::new_current_thread() .on_task_spawn(mk_spawn_location_hook( "(current_thread) on_task_spawn", &spawns, )) .on_before_task_poll(mk_spawn_location_hook( "(current_thread) on_before_task_poll", &poll_starts, )) .on_after_task_poll(mk_spawn_location_hook( "(current_thread) on_after_task_poll", &poll_ends, )) .build() .unwrap(); let task = runtime.spawn(async move { tokio::task::yield_now().await }); runtime.block_on(async move { // Spawn tasks using both `runtime.spawn(...)` and `tokio::spawn(...)` // to ensure the correct location is captured in both code paths. task.await.unwrap(); tokio::spawn(async move {}).await.unwrap(); // tick the runtime a bunch to close out tasks for _ in 0..ITERATIONS { tokio::task::yield_now().await; } }); assert_eq!(spawns.load(Ordering::SeqCst), 2); let poll_starts = poll_starts.load(Ordering::SeqCst); assert!(poll_starts > 2); assert_eq!(poll_starts, poll_ends.load(Ordering::SeqCst)); } /// Test that the correct spawn location is provided to the task hooks on a /// multi-thread runtime. /// /// Testing this separately is necessary as the spawn code paths are different /// and we should ensure that `#[track_caller]` is passed through correctly /// for both runtimes. 
#[cfg_attr( target_os = "wasi", ignore = "WASI does not support multi-threaded runtime" )] #[test] fn task_hook_spawn_location_multi_thread() { let spawns = Arc::new(AtomicUsize::new(0)); let poll_starts = Arc::new(AtomicUsize::new(0)); let poll_ends = Arc::new(AtomicUsize::new(0)); let runtime = Builder::new_multi_thread() .on_task_spawn(mk_spawn_location_hook( "(multi_thread) on_task_spawn", &spawns, )) .on_before_task_poll(mk_spawn_location_hook( "(multi_thread) on_before_task_poll", &poll_starts, )) .on_after_task_poll(mk_spawn_location_hook( "(multi_thread) on_after_task_poll", &poll_ends, )) .build() .unwrap(); let task = runtime.spawn(async move { tokio::task::yield_now().await }); runtime.block_on(async move { // Spawn tasks using both `runtime.spawn(...)` and `tokio::spawn(...)` // to ensure the correct location is captured in both code paths. task.await.unwrap(); tokio::spawn(async move {}).await.unwrap(); // tick the runtime a bunch to close out tasks for _ in 0..ITERATIONS { tokio::task::yield_now().await; } }); // Give the runtime to shut down so that we see all the expected calls to // the task hooks. runtime.shutdown_timeout(std::time::Duration::from_secs(60)); // Note: we "read" the counters using `fetch_add(0, SeqCst)` rather than // `load(SeqCst)` because read-write-modify operations are guaranteed to // observe the latest value, while the load is not. // This avoids a race that may cause test flakiness. assert_eq!(spawns.fetch_add(0, Ordering::SeqCst), 2); let poll_starts = poll_starts.fetch_add(0, Ordering::SeqCst); assert!(poll_starts > 2); assert_eq!(poll_starts, poll_ends.fetch_add(0, Ordering::SeqCst)); } fn mk_spawn_location_hook( event: &'static str, count: &Arc, ) -> impl Fn(&tokio::runtime::TaskMeta<'_>) { let count = Arc::clone(count); move |data| { eprintln!("{event} ({:?}): {:?}", data.id(), data.spawned_at()); // Assert that the spawn location is in this file. 
// Don't make assertions about line number/column here, as these // may change as new code is added to the test file... assert_eq!( data.spawned_at().file(), file!(), "incorrect spawn location in {event} hook", ); count.fetch_add(1, Ordering::SeqCst); } }