From f5822a02329eacdc4fa756f422c274dda25cc127 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Thu, 13 Nov 2025 13:45:46 +0100 Subject: [PATCH 1/2] Introduce `InMemoryStore` for testing Recently, `rust-lightning` broke the (async) API of the `TestStore`, making it ~impossible to use in regular tests. Here, we un-DRY our `TestStore` implementation and simply copy over the previous `TestStore` version, now named `InMemoryStore` to discern the objects. We also switch all feasible instances over to use `InMemoryStore` rather than LDK's `test_utils::TestStore`. --- src/data_store.rs | 5 +- src/event.rs | 7 +- src/io/test_utils.rs | 128 +++++++++++++++++- .../asynchronous/static_invoice_store.rs | 4 +- src/peer_store.rs | 5 +- 5 files changed, 138 insertions(+), 11 deletions(-) diff --git a/src/data_store.rs b/src/data_store.rs index 83cbf4476..87bd831c9 100644 --- a/src/data_store.rs +++ b/src/data_store.rs @@ -172,10 +172,11 @@ where #[cfg(test)] mod tests { use lightning::impl_writeable_tlv_based; - use lightning::util::test_utils::{TestLogger, TestStore}; + use lightning::util::test_utils::TestLogger; use super::*; use crate::hex_utils; + use crate::io::test_utils::InMemoryStore; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] struct TestObjectId { @@ -234,7 +235,7 @@ mod tests { #[test] fn data_is_persisted() { - let store: Arc = Arc::new(TestStore::new(false)); + let store: Arc = Arc::new(InMemoryStore::new()); let logger = Arc::new(TestLogger::new()); let primary_namespace = "datastore_test_primary".to_string(); let secondary_namespace = "datastore_test_secondary".to_string(); diff --git a/src/event.rs b/src/event.rs index 1946350a3..42b60e213 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1605,13 +1605,14 @@ mod tests { use std::sync::atomic::{AtomicU16, Ordering}; use std::time::Duration; - use lightning::util::test_utils::{TestLogger, TestStore}; + use lightning::util::test_utils::TestLogger; use super::*; + use crate::io::test_utils::InMemoryStore; 
#[tokio::test] async fn event_queue_persistence() { - let store: Arc = Arc::new(TestStore::new(false)); + let store: Arc = Arc::new(InMemoryStore::new()); let logger = Arc::new(TestLogger::new()); let event_queue = Arc::new(EventQueue::new(Arc::clone(&store), Arc::clone(&logger))); assert_eq!(event_queue.next_event(), None); @@ -1647,7 +1648,7 @@ mod tests { #[tokio::test] async fn event_queue_concurrency() { - let store: Arc = Arc::new(TestStore::new(false)); + let store: Arc = Arc::new(InMemoryStore::new()); let logger = Arc::new(TestLogger::new()); let event_queue = Arc::new(EventQueue::new(Arc::clone(&store), Arc::clone(&logger))); assert_eq!(event_queue.next_event(), None); diff --git a/src/io/test_utils.rs b/src/io/test_utils.rs index fd4de1c9f..310638dd8 100644 --- a/src/io/test_utils.rs +++ b/src/io/test_utils.rs @@ -5,8 +5,13 @@ // http://opensource.org/licenses/MIT>, at your option. You may not use this file except in // accordance with one or both of these licenses. +use std::boxed::Box; +use std::collections::{hash_map, HashMap}; +use std::future::Future; use std::panic::RefUnwindSafe; use std::path::PathBuf; +use std::pin::Pin; +use std::sync::Mutex; use lightning::events::ClosureReason; use lightning::ln::functional_test_utils::{ @@ -14,10 +19,10 @@ use lightning::ln::functional_test_utils::{ create_network, create_node_cfgs, create_node_chanmgrs, send_payment, TestChanMonCfg, }; use lightning::util::persist::{ - KVStoreSync, MonitorUpdatingPersister, KVSTORE_NAMESPACE_KEY_MAX_LEN, + KVStore, KVStoreSync, MonitorUpdatingPersister, KVSTORE_NAMESPACE_KEY_MAX_LEN, }; use lightning::util::test_utils; -use lightning::{check_added_monitors, check_closed_broadcast, check_closed_event}; +use lightning::{check_added_monitors, check_closed_broadcast, check_closed_event, io}; use rand::distr::Alphanumeric; use rand::{rng, Rng}; @@ -32,6 +37,125 @@ type TestMonitorUpdatePersister<'a, K> = MonitorUpdatingPersister< const EXPECTED_UPDATES_PER_PAYMENT: u64 = 5; +pub 
struct InMemoryStore {
+	persisted_bytes: Mutex<HashMap<String, HashMap<String, Vec<u8>>>>,
+}
+
+impl InMemoryStore {
+	pub fn new() -> Self {
+		let persisted_bytes = Mutex::new(HashMap::new());
+		Self { persisted_bytes }
+	}
+
+	fn read_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
+	) -> io::Result<Vec<u8>> {
+		let persisted_lock = self.persisted_bytes.lock().unwrap();
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
+
+		if let Some(outer_ref) = persisted_lock.get(&prefixed) {
+			if let Some(inner_ref) = outer_ref.get(key) {
+				let bytes = inner_ref.clone();
+				Ok(bytes)
+			} else {
+				Err(io::Error::new(io::ErrorKind::NotFound, "Key not found"))
+			}
+		} else {
+			Err(io::Error::new(io::ErrorKind::NotFound, "Namespace not found"))
+		}
+	}
+
+	fn write_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
+	) -> io::Result<()> {
+		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
+		let outer_e = persisted_lock.entry(prefixed).or_insert(HashMap::new());
+		outer_e.insert(key.to_string(), buf);
+		Ok(())
+	}
+
+	fn remove_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, _lazy: bool,
+	) -> io::Result<()> {
+		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
+		if let Some(outer_ref) = persisted_lock.get_mut(&prefixed) {
+			outer_ref.remove(&key.to_string());
+		}
+
+		Ok(())
+	}
+
+	fn list_internal(
+		&self, primary_namespace: &str, secondary_namespace: &str,
+	) -> io::Result<Vec<String>> {
+		let mut persisted_lock = self.persisted_bytes.lock().unwrap();
+
+		let prefixed = format!("{primary_namespace}/{secondary_namespace}");
+		match persisted_lock.entry(prefixed) {
+			hash_map::Entry::Occupied(e) => Ok(e.get().keys().cloned().collect()),
+			hash_map::Entry::Vacant(_) => Ok(Vec::new()),
+		}
+	}
+}
+
+impl KVStore for InMemoryStore {
+	fn read(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
+	) -> Pin<Box<dyn Future<Output = Result<Vec<u8>, io::Error>> + 'static + Send>> {
+		let res = self.read_internal(&primary_namespace, &secondary_namespace, &key);
+		Box::pin(async move { res })
+	}
+	fn write(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
+	) -> Pin<Box<dyn Future<Output = Result<(), io::Error>> + 'static + Send>> {
+		let res = self.write_internal(&primary_namespace, &secondary_namespace, &key, buf);
+		Box::pin(async move { res })
+	}
+	fn remove(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool,
+	) -> Pin<Box<dyn Future<Output = Result<(), io::Error>> + 'static + Send>> {
+		let res = self.remove_internal(&primary_namespace, &secondary_namespace, &key, lazy);
+		Box::pin(async move { res })
+	}
+	fn list(
+		&self, primary_namespace: &str, secondary_namespace: &str,
+	) -> Pin<Box<dyn Future<Output = Result<Vec<String>, io::Error>> + 'static + Send>> {
+		let res = self.list_internal(primary_namespace, secondary_namespace);
+		Box::pin(async move { res })
+	}
+}
+
+impl KVStoreSync for InMemoryStore {
+	fn read(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str,
+	) -> io::Result<Vec<u8>> {
+		self.read_internal(primary_namespace, secondary_namespace, key)
+	}
+
+	fn write(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, buf: Vec<u8>,
+	) -> io::Result<()> {
+		self.write_internal(primary_namespace, secondary_namespace, key, buf)
+	}
+
+	fn remove(
+		&self, primary_namespace: &str, secondary_namespace: &str, key: &str, lazy: bool,
+	) -> io::Result<()> {
+		self.remove_internal(primary_namespace, secondary_namespace, key, lazy)
+	}
+
+	fn list(&self, primary_namespace: &str, secondary_namespace: &str) -> io::Result<Vec<String>> {
+		self.list_internal(primary_namespace, secondary_namespace)
+	}
+}
+
+unsafe impl Sync for InMemoryStore {}
+unsafe impl Send for InMemoryStore {}
+
 pub(crate) fn random_storage_path() -> PathBuf {
 	let mut temp_path = std::env::temp_dir();
 	let mut rng = rng();
diff --git a/src/payment/asynchronous/static_invoice_store.rs
b/src/payment/asynchronous/static_invoice_store.rs index a7e2d2f9e..45125cfee 100644 --- a/src/payment/asynchronous/static_invoice_store.rs +++ b/src/payment/asynchronous/static_invoice_store.rs @@ -157,15 +157,15 @@ mod tests { use lightning::offers::offer::OfferBuilder; use lightning::offers::static_invoice::{StaticInvoice, StaticInvoiceBuilder}; use lightning::sign::EntropySource; - use lightning::util::test_utils::TestStore; use lightning_types::features::BlindedHopFeatures; + use crate::io::test_utils::InMemoryStore; use crate::payment::asynchronous::static_invoice_store::StaticInvoiceStore; use crate::types::DynStore; #[tokio::test] async fn static_invoice_store_test() { - let store: Arc = Arc::new(TestStore::new(false)); + let store: Arc = Arc::new(InMemoryStore::new()); let static_invoice_store = StaticInvoiceStore::new(Arc::clone(&store)); let static_invoice = invoice(); diff --git a/src/peer_store.rs b/src/peer_store.rs index 82c80c396..59cd3d94f 100644 --- a/src/peer_store.rs +++ b/src/peer_store.rs @@ -152,13 +152,14 @@ mod tests { use std::str::FromStr; use std::sync::Arc; - use lightning::util::test_utils::{TestLogger, TestStore}; + use lightning::util::test_utils::TestLogger; use super::*; + use crate::io::test_utils::InMemoryStore; #[test] fn peer_info_persistence() { - let store: Arc = Arc::new(TestStore::new(false)); + let store: Arc = Arc::new(InMemoryStore::new()); let logger = Arc::new(TestLogger::new()); let peer_store = PeerStore::new(Arc::clone(&store), Arc::clone(&logger)); From 4c7254139dc2eb09424cefda7c8c357b71d69c14 Mon Sep 17 00:00:00 2001 From: Elias Rohrer Date: Mon, 10 Nov 2025 13:48:51 +0100 Subject: [PATCH 2/2] Make `EventQueue` persistence `async` Previously, we'd still use `KVStoreSync` for persistence of our event queue, which also meant calling the sync persistence through our otherwise-async background processor/event handling flow. 
Here we switch our `EventQueue` persistence to be async, which gets us one step further towards async-everything. --- src/event.rs | 63 ++++++++++++++++++++++++++++------------------------ src/lib.rs | 5 ++++- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/src/event.rs b/src/event.rs index 42b60e213..3de2c3261 100644 --- a/src/event.rs +++ b/src/event.rs @@ -26,7 +26,7 @@ use lightning::util::config::{ ChannelConfigOverrides, ChannelConfigUpdate, ChannelHandshakeConfigUpdate, }; use lightning::util::errors::APIError; -use lightning::util::persist::KVStoreSync; +use lightning::util::persist::KVStore; use lightning::util::ser::{Readable, ReadableArgs, Writeable, Writer}; use lightning_liquidity::lsps2::utils::compute_opening_fee; use lightning_types::payment::{PaymentHash, PaymentPreimage}; @@ -301,12 +301,14 @@ where Self { queue, waker, kv_store, logger } } - pub(crate) fn add_event(&self, event: Event) -> Result<(), Error> { - { + pub(crate) async fn add_event(&self, event: Event) -> Result<(), Error> { + let data = { let mut locked_queue = self.queue.lock().unwrap(); locked_queue.push_back(event); - self.persist_queue(&locked_queue)?; - } + EventQueueSerWrapper(&locked_queue).encode() + }; + + self.persist_queue(data).await?; if let Some(waker) = self.waker.lock().unwrap().take() { waker.wake(); @@ -323,12 +325,14 @@ where EventFuture { event_queue: Arc::clone(&self.queue), waker: Arc::clone(&self.waker) }.await } - pub(crate) fn event_handled(&self) -> Result<(), Error> { - { + pub(crate) async fn event_handled(&self) -> Result<(), Error> { + let data = { let mut locked_queue = self.queue.lock().unwrap(); locked_queue.pop_front(); - self.persist_queue(&locked_queue)?; - } + EventQueueSerWrapper(&locked_queue).encode() + }; + + self.persist_queue(data).await?; if let Some(waker) = self.waker.lock().unwrap().take() { waker.wake(); @@ -336,15 +340,15 @@ where Ok(()) } - fn persist_queue(&self, locked_queue: &VecDeque) -> Result<(), Error> { - let 
data = EventQueueSerWrapper(locked_queue).encode(); - KVStoreSync::write( + async fn persist_queue(&self, encoded_queue: Vec) -> Result<(), Error> { + KVStore::write( &*self.kv_store, EVENT_QUEUE_PERSISTENCE_PRIMARY_NAMESPACE, EVENT_QUEUE_PERSISTENCE_SECONDARY_NAMESPACE, EVENT_QUEUE_PERSISTENCE_KEY, - data, + encoded_queue, ) + .await .map_err(|e| { log_error!( self.logger, @@ -694,7 +698,7 @@ where claim_deadline, custom_records, }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => return Ok(()), Err(e) => { log_error!( @@ -928,7 +932,7 @@ where .map(|cf| cf.custom_tlvs().into_iter().map(|tlv| tlv.into()).collect()) .unwrap_or_default(), }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => return Ok(()), Err(e) => { log_error!(self.logger, "Failed to push to event queue: {}", e); @@ -988,7 +992,7 @@ where fee_paid_msat, }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => return Ok(()), Err(e) => { log_error!(self.logger, "Failed to push to event queue: {}", e); @@ -1019,7 +1023,7 @@ where let event = Event::PaymentFailed { payment_id: Some(payment_id), payment_hash, reason }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => return Ok(()), Err(e) => { log_error!(self.logger, "Failed to push to event queue: {}", e); @@ -1295,7 +1299,7 @@ where claim_from_onchain_tx, outbound_amount_forwarded_msat, }; - self.event_queue.add_event(event).map_err(|e| { + self.event_queue.add_event(event).await.map_err(|e| { log_error!(self.logger, "Failed to push to event queue: {}", e); ReplayEvent() })?; @@ -1322,7 +1326,7 @@ where counterparty_node_id, funding_txo, }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => {}, Err(e) => { log_error!(self.logger, "Failed to push to event queue: {}", e); @@ -1383,7 +1387,7 @@ 
where user_channel_id: UserChannelId(user_channel_id), counterparty_node_id: Some(counterparty_node_id), }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => {}, Err(e) => { log_error!(self.logger, "Failed to push to event queue: {}", e); @@ -1407,7 +1411,7 @@ where reason: Some(reason), }; - match self.event_queue.add_event(event) { + match self.event_queue.add_event(event).await { Ok(_) => {}, Err(e) => { log_error!(self.logger, "Failed to push to event queue: {}", e); @@ -1622,7 +1626,7 @@ mod tests { user_channel_id: UserChannelId(2323), counterparty_node_id: None, }; - event_queue.add_event(expected_event.clone()).unwrap(); + event_queue.add_event(expected_event.clone()).await.unwrap(); // Check we get the expected event and that it is returned until we mark it handled. for _ in 0..5 { @@ -1631,18 +1635,19 @@ mod tests { } // Check we can read back what we persisted. - let persisted_bytes = KVStoreSync::read( + let persisted_bytes = KVStore::read( &*store, EVENT_QUEUE_PERSISTENCE_PRIMARY_NAMESPACE, EVENT_QUEUE_PERSISTENCE_SECONDARY_NAMESPACE, EVENT_QUEUE_PERSISTENCE_KEY, ) + .await .unwrap(); let deser_event_queue = EventQueue::read(&mut &persisted_bytes[..], (Arc::clone(&store), logger)).unwrap(); assert_eq!(deser_event_queue.next_event_async().await, expected_event); - event_queue.event_handled().unwrap(); + event_queue.event_handled().await.unwrap(); assert_eq!(event_queue.next_event(), None); } @@ -1676,28 +1681,28 @@ mod tests { let mut delayed_enqueue = false; for _ in 0..25 { - event_queue.add_event(expected_event.clone()).unwrap(); + event_queue.add_event(expected_event.clone()).await.unwrap(); enqueued_events.fetch_add(1, Ordering::SeqCst); } loop { tokio::select! 
{ _ = tokio::time::sleep(Duration::from_millis(10)), if !delayed_enqueue => { - event_queue.add_event(expected_event.clone()).unwrap(); + event_queue.add_event(expected_event.clone()).await.unwrap(); enqueued_events.fetch_add(1, Ordering::SeqCst); delayed_enqueue = true; } e = event_queue.next_event_async() => { assert_eq!(e, expected_event); - event_queue.event_handled().unwrap(); + event_queue.event_handled().await.unwrap(); received_events.fetch_add(1, Ordering::SeqCst); - event_queue.add_event(expected_event.clone()).unwrap(); + event_queue.add_event(expected_event.clone()).await.unwrap(); enqueued_events.fetch_add(1, Ordering::SeqCst); } e = event_queue.next_event_async() => { assert_eq!(e, expected_event); - event_queue.event_handled().unwrap(); + event_queue.event_handled().await.unwrap(); received_events.fetch_add(1, Ordering::SeqCst); } } diff --git a/src/lib.rs b/src/lib.rs index 701a14dde..982673f4a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -777,7 +777,10 @@ impl Node { /// /// **Note:** This **MUST** be called after each event has been handled. pub fn event_handled(&self) -> Result<(), Error> { - self.event_queue.event_handled().map_err(|e| { + // We use our runtime for the sync variant to ensure `tokio::task::block_in_place` is + // always called if we'd ever hit this in an outer runtime context. + let fut = self.event_queue.event_handled(); + self.runtime.block_on(fut).map_err(|e| { log_error!( self.logger, "Couldn't mark event handled due to persistence failure: {}",