Refactor counter increment sites for TwoPhaseCounter.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,56 +1,51 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::sync::Arc;
|
||||
|
||||
use tuwunel_core::{Result, utils};
|
||||
use tuwunel_core::{
|
||||
Result, utils,
|
||||
utils::two_phase_counter::{Counter as TwoPhaseCounter, Permit as TwoPhasePermit},
|
||||
};
|
||||
use tuwunel_database::{Database, Deserialized, Map};
|
||||
|
||||
pub struct Data {
|
||||
global: Arc<Map>,
|
||||
counter: RwLock<u64>,
|
||||
counter: Arc<Counter>,
|
||||
pub(super) db: Arc<Database>,
|
||||
}
|
||||
|
||||
pub(super) type Permit = TwoPhasePermit<Callback>;
|
||||
type Counter = TwoPhaseCounter<Callback>;
|
||||
type Callback = Box<dyn Fn(u64) -> Result + Send + Sync>;
|
||||
|
||||
const COUNTER: &[u8] = b"c";
|
||||
|
||||
impl Data {
|
||||
pub(super) fn new(args: &crate::Args<'_>) -> Self {
|
||||
let db = &args.db;
|
||||
let db = args.db.clone();
|
||||
Self {
|
||||
global: db["global"].clone(),
|
||||
counter: RwLock::new(
|
||||
Self::stored_count(&db["global"]).expect("initialized global counter"),
|
||||
),
|
||||
db: args.db.clone(),
|
||||
global: args.db["global"].clone(),
|
||||
counter: Counter::new(
|
||||
Self::stored_count(&args.db["global"]).expect("initialized global counter"),
|
||||
Box::new(move |count| Self::store_count(&db, &db["global"], count)),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_count(&self) -> Result<u64> {
|
||||
let _cork = self.db.cork();
|
||||
let mut lock = self.counter.write().expect("locked");
|
||||
let counter: &mut u64 = &mut lock;
|
||||
debug_assert!(
|
||||
*counter == Self::stored_count(&self.global).expect("database failure"),
|
||||
"counter mismatch"
|
||||
);
|
||||
|
||||
*counter = counter
|
||||
.checked_add(1)
|
||||
.expect("counter must not overflow u64");
|
||||
|
||||
self.global.insert(COUNTER, counter.to_be_bytes());
|
||||
|
||||
Ok(*counter)
|
||||
#[inline]
|
||||
pub fn next_count(&self) -> Permit {
|
||||
self.counter
|
||||
.next()
|
||||
.expect("failed to obtain next sequence number")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn current_count(&self) -> u64 {
|
||||
let lock = self.counter.read().expect("locked");
|
||||
let counter: &u64 = &lock;
|
||||
debug_assert!(
|
||||
*counter == Self::stored_count(&self.global).expect("database failure"),
|
||||
"counter mismatch"
|
||||
);
|
||||
pub fn current_count(&self) -> u64 { self.counter.current() }
|
||||
|
||||
*counter
|
||||
fn store_count(db: &Arc<Database>, global: &Arc<Map>, count: u64) -> Result {
|
||||
let _cork = db.cork();
|
||||
global.insert(COUNTER, count.to_be_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn stored_count(global: &Arc<Map>) -> Result<u64> {
|
||||
@@ -59,6 +54,12 @@ impl Data {
|
||||
.as_deref()
|
||||
.map_or(Ok(0_u64), utils::u64_from_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Data {
|
||||
pub fn bump_database_version(&self, new_version: u64) {
|
||||
self.global.raw_put(b"version", new_version);
|
||||
}
|
||||
|
||||
pub async fn database_version(&self) -> u64 {
|
||||
self.global
|
||||
@@ -67,9 +68,4 @@ impl Data {
|
||||
.deserialized()
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn bump_database_version(&self, new_version: u64) {
|
||||
self.global.raw_put(b"version", new_version);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user