diff --git a/src/lib.rs b/src/lib.rs index 736d9bf..ffff76b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -350,6 +350,24 @@ pub trait Watcher: Clone { } } + fn filter(mut self, filter: impl Fn(&T) -> bool + Send + Sync + 'static) -> Filter + where + T: Clone + Eq, + Self: Watcher, + { + let current = self.get(); + let current = if filter(¤t) { + Some(current) + } else { + None + }; + Filter { + current, + filter: Arc::new(filter), + watcher: self, + } + } + /// Returns a watcher that updates every time this or the other watcher /// updates, and yields both watcher's items together when that happens. fn or(self, other: W) -> (Self, W) { @@ -519,6 +537,57 @@ impl Watcher for Map { } } +/// Wraps a [`Watcher`] to allow observing a derived value. +/// +/// See [`Watcher::map`]. +#[derive(derive_more::Debug, Clone)] +pub struct Filter +where + T: Clone + Eq, + W: Watcher, +{ + #[debug("Arc bool + 'static>")] + filter: Arc bool + Send + Sync + 'static>, + watcher: W, + current: Option, +} + +impl Watcher for Filter +where + T: Clone + Eq, + W: Watcher, +{ + type Value = Option; + + fn get(&mut self) -> Self::Value { + self.current.clone() + } + + fn is_connected(&self) -> bool { + self.watcher.is_connected() + } + + fn poll_updated( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + loop { + let value = ready!(self.watcher.poll_updated(cx)?); + let filtered = if (self.filter)(&value) { + Some(value) + } else { + None + }; + if filtered != self.current { + self.current = filtered.clone(); + return Poll::Ready(Ok(filtered)); + } else { + self.current = filtered; + } + } + } +} + /// Future returning the next item after the current one in a [`Watcher`]. /// /// See [`Watcher::updated`]. @@ -1005,4 +1074,50 @@ mod tests { assert!(!a.has_watchers()); assert!(!b.has_watchers()); } + + #[tokio::test] + async fn test_filter_basic() { + let a = Watchable::new(1u8); + let mut filtered = a.watch().filter(|x| *x > 2 && *x < 6); + + assert_eq!(filtered.get(), None); + + let handle = tokio::task::spawn(async move { filtered.stream().collect::>().await }); + + for i in 2u8..10 { + a.set(i).unwrap(); + tokio::task::yield_now().await; + } + drop(a); + + let values = tokio::time::timeout(Duration::from_secs(5), handle) + .await + .unwrap() + .unwrap(); + + assert_eq!(values, vec![None, Some(3u8), Some(4), Some(5), None]); + } + + #[tokio::test] + async fn test_filter_init() { + let a = Watchable::new(1u8); + let mut filtered = a.watch().filter(|x| *x > 2 && *x < 6); + + assert_eq!(filtered.get(), None); + + let handle = tokio::task::spawn(async move { filtered.initialized().await }); + + for i in 2u8..10 { + a.set(i).unwrap(); + tokio::task::yield_now().await; + } + drop(a); + + let value = tokio::time::timeout(Duration::from_secs(5), handle) + .await + .unwrap() + .unwrap(); + + assert_eq!(value, 3); + } }