diff --git a/src/core/utils/two_phase_counter.rs b/src/core/utils/two_phase_counter.rs index 3573d377..51888000 100644 --- a/src/core/utils/two_phase_counter.rs +++ b/src/core/utils/two_phase_counter.rs @@ -44,6 +44,10 @@ pub struct State Result + Sync> { /// this list is the "retirement" sequence number where all writes have /// completed and all reads are globally visible. pending: VecDeque, + + /// Callback to notify updates of the retirement value. This is likely + /// called from the destructor of a permit/guard; try not to panic. + release: F, } #[derive(Debug)] @@ -63,8 +67,10 @@ impl Result + Sync> Counter { /// Construct a new Two-Phase counter state. The value of `init` is /// considered retired, and the next sequence number dispatched will be one /// greater. - pub fn new(init: u64, commit: F) -> Arc { - Arc::new(Self { inner: State::new(init, commit).into() }) + pub fn new(init: u64, commit: F, release: F) -> Arc { + Arc::new(Self { + inner: State::new(init, commit, release).into(), + }) } /// Obtain a sequence number to conduct write operations for the scope. @@ -83,16 +89,27 @@ impl Result + Sync> Counter { .expect("locked for reading") .retired() } + + /// Load the highest sequence number (dispatched); may still be pending or + /// may be retired. + #[inline] + pub fn dispatched(&self) -> u64 { + self.inner + .read() + .expect("locked for reading") + .dispatched + } } impl Result + Sync> State { /// Create new state, starting from `init`. The next sequence number /// dispatched will be one greater than `init`. - fn new(dispatched: u64, commit: F) -> Self { + fn new(dispatched: u64, commit: F, release: F) -> Self { Self { dispatched, commit, pending: VecDeque::new(), + release, } } @@ -127,6 +144,10 @@ impl Result + Sync> State { .expect("sequence number at index must be removed"); debug_assert!(removed == id, "sequence number removed must match id"); + + if index == 0 { + (self.release)(id).expect("release callback should not error"); + } } /// Calculate the retired sequence number, one less than the lowest pending diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index ea64fa03..c32bc50f 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,13 +1,15 @@ use std::sync::Arc; +use tokio::sync::{watch, watch::Sender}; use tuwunel_core::{ - Result, utils, + Result, err, utils, utils::two_phase_counter::{Counter as TwoPhaseCounter, Permit as TwoPhasePermit}, }; use tuwunel_database::{Database, Deserialized, Map}; pub struct Data { global: Arc, + retires: Sender, counter: Arc, pub(super) db: Arc, } @@ -21,16 +23,39 @@ const COUNTER: &[u8] = b"c"; impl Data { pub(super) fn new(args: &crate::Args<'_>) -> Self { let db = args.db.clone(); + let count = Self::stored_count(&args.db["global"]).expect("initialize global counter"); + let retires = watch::channel(count).0; Self { db: args.db.clone(), global: args.db["global"].clone(), + retires: retires.clone(), counter: Counter::new( - Self::stored_count(&args.db["global"]).expect("initialized global counter"), + count, Box::new(move |count| Self::store_count(&db, &db["global"], count)), + Box::new(move |count| Self::handle_retire(&retires, count)), ), } } + pub async fn wait_pending(&self) -> Result { + let count = self.counter.dispatched(); + self.wait_count(&count).await.inspect(|retired| { + debug_assert!( + *retired >= count, + "Expecting retired sequence number >= snapshotted dispatch number" + ); + }) + } + + pub async fn wait_count(&self, count: &u64) -> Result { + self.retires + .subscribe() + .wait_for(|retired| retired.ge(count)) + .await + .map(|retired| *retired) + .map_err(|e| err!("counter channel error {e:?}")) + } + #[inline] pub fn next_count(&self) -> Permit { self.counter @@ -41,6 +66,12 @@ impl Data { #[inline] pub fn current_count(&self) -> u64 { self.counter.current() } + fn handle_retire(sender: &Sender, count: u64) -> Result { + let _prev = sender.send_replace(count); + + Ok(()) + } + fn store_count(db: &Arc, global: &Arc, count: u64) -> Result { let _cork = db.cork(); global.insert(COUNTER, count.to_be_bytes()); diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 29ecc0ab..5d96325b 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -102,6 +102,12 @@ impl crate::Service for Service { } impl Service { + #[inline] + pub async fn wait_pending(&self) -> Result { self.db.wait_pending().await } + + #[inline] + pub async fn wait_count(&self, count: &u64) -> Result { self.db.wait_count(count).await } + #[inline] #[must_use] pub fn next_count(&self) -> data::Permit { self.db.next_count() }