use std::{ borrow::ToOwned, fmt::Debug, hash::Hash, sync::{Arc, TryLockError::WouldBlock}, }; use tokio::sync::OwnedMutexGuard as Omg; use crate::{Result, err}; /// Map of Mutexes #[derive(Debug)] pub struct MutexMap { map: Map, } #[derive(Debug)] pub struct Guard { map: Map, val: Omg, } type Map = Arc>; type MapMutex = std::sync::Mutex>; type HashMap = std::collections::HashMap>; type Value = Arc>; impl MutexMap where Key: Clone + Eq + Hash + Send, Val: Default + Send, { #[must_use] pub fn new() -> Self { Self { map: Map::new(MapMutex::new(HashMap::new())), } } #[tracing::instrument(level = "trace", skip(self))] pub async fn lock(&self, k: &K) -> Guard where K: Debug + Send + ?Sized + Sync + ToOwned, { let val = self .map .lock() .expect("locked") .entry(k.to_owned()) .or_default() .clone(); Guard:: { map: Arc::clone(&self.map), val: val.lock_owned().await, } } #[tracing::instrument(level = "trace", skip(self))] pub fn try_lock(&self, k: &K) -> Result> where K: Debug + Send + ?Sized + Sync + ToOwned, { let val = self .map .lock() .expect("locked") .entry(k.to_owned()) .or_default() .clone(); Ok(Guard:: { map: Arc::clone(&self.map), val: val .try_lock_owned() .map_err(|_| err!("would yield"))?, }) } #[tracing::instrument(level = "trace", skip(self))] pub fn try_try_lock(&self, k: &K) -> Result> where K: Debug + Send + ?Sized + Sync + ToOwned, { let val = self .map .try_lock() .map_err(|e| match e { | WouldBlock => err!("would block"), | _ => panic!("{e:?}"), })? .entry(k.to_owned()) .or_default() .clone(); Ok(Guard:: { map: Arc::clone(&self.map), val: val .try_lock_owned() .map_err(|_| err!("would yield"))?, }) } #[must_use] pub fn contains(&self, k: &Key) -> bool { self.map.lock().expect("locked").contains_key(k) } #[must_use] pub fn is_empty(&self) -> bool { self.map.lock().expect("locked").is_empty() } #[must_use] pub fn len(&self) -> usize { self.map.lock().expect("locked").len() } } impl Default for MutexMap where Key: Clone + Eq + Hash + Send, Val: Default + Send, { fn default() -> Self { Self::new() } } impl Drop for Guard { #[tracing::instrument(name = "unlock", level = "trace", skip_all)] fn drop(&mut self) { if Arc::strong_count(Omg::mutex(&self.val)) <= 2 { self.map.lock().expect("locked").retain(|_, val| { !Arc::ptr_eq(val, Omg::mutex(&self.val)) || Arc::strong_count(val) > 2 }); } } }