Skip to content

Commit fbaaa8d

Browse files
Use FnMut closures for receiver typestate checks (#883)
In our liana wallet integration we are encountering some issues due to how the liana devs have setup their database, namely that their db always references a mutable self even on read calls within the db so we need to have some mutability when accessing things like an address or outpoint from the db. This was not possible when our closures were limited to non-mutable. I considered using FnOnce but we have some map methods that prevent us from doing that as map have the possiblity of forcing the closure being called more than once which FnOnce is limited to. N.B. I did not make all of the closures for the typestates mutable but perhaps we should think about their consistency for integration devs and to limit rust-payjoin from deciding how devs should implement these type state checks.
2 parents b9b9948 + d81d8ae commit fbaaa8d

File tree

8 files changed

+134
-51
lines changed

8 files changed

+134
-51
lines changed

payjoin-cli/src/app/v1.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -321,17 +321,18 @@ impl App {
321321
let _to_broadcast_in_failure_case = proposal.extract_tx_to_schedule_broadcast();
322322

323323
// Receive Check 2: receiver can't sign for proposal inputs
324-
let proposal = proposal.check_inputs_not_owned(|input| {
324+
let proposal = proposal.check_inputs_not_owned(&mut |input| {
325325
wallet.is_mine(input).map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
326326
})?;
327327
log::trace!("check2");
328328

329329
// Receive Check 3: have we seen this input before? More of a check for non-interactive i.e. payment processor receivers.
330-
let payjoin = proposal
331-
.check_no_inputs_seen_before(|input| Ok(self.db.insert_input_seen_before(*input)?))?;
330+
let payjoin = proposal.check_no_inputs_seen_before(&mut |input| {
331+
Ok(self.db.insert_input_seen_before(*input)?)
332+
})?;
332333
log::trace!("check3");
333334

334-
let payjoin = payjoin.identify_receiver_outputs(|output_script| {
335+
let payjoin = payjoin.identify_receiver_outputs(&mut |output_script| {
335336
wallet
336337
.is_mine(output_script)
337338
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))

payjoin-cli/src/app/v2/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ impl App {
374374
) -> Result<()> {
375375
let wallet = self.wallet();
376376
let proposal = proposal
377-
.check_inputs_not_owned(|input| {
377+
.check_inputs_not_owned(&mut |input| {
378378
wallet
379379
.is_mine(input)
380380
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))
@@ -389,7 +389,9 @@ impl App {
389389
persister: &ReceiverPersister,
390390
) -> Result<()> {
391391
let proposal = proposal
392-
.check_no_inputs_seen_before(|input| Ok(self.db.insert_input_seen_before(*input)?))
392+
.check_no_inputs_seen_before(&mut |input| {
393+
Ok(self.db.insert_input_seen_before(*input)?)
394+
})
393395
.save(persister)?;
394396
self.identify_receiver_outputs(proposal, persister).await
395397
}
@@ -401,7 +403,7 @@ impl App {
401403
) -> Result<()> {
402404
let wallet = self.wallet();
403405
let proposal = proposal
404-
.identify_receiver_outputs(|output_script| {
406+
.identify_receiver_outputs(&mut |output_script| {
405407
wallet
406408
.is_mine(output_script)
407409
.map_err(|e| ImplementationError::from(e.into_boxed_dyn_error()))

payjoin-ffi/src/receive/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ impl MaybeInputsOwned {
417417
is_owned: impl Fn(&Vec<u8>) -> Result<bool, ImplementationError>,
418418
) -> MaybeInputsOwnedTransition {
419419
MaybeInputsOwnedTransition(Arc::new(RwLock::new(Some(
420-
self.0.clone().check_inputs_not_owned(|input| Ok(is_owned(&input.to_bytes())?)),
420+
self.0.clone().check_inputs_not_owned(&mut |input| Ok(is_owned(&input.to_bytes())?)),
421421
))))
422422
}
423423
}
@@ -473,7 +473,7 @@ impl MaybeInputsSeen {
473473
MaybeInputsSeenTransition(Arc::new(RwLock::new(Some(
474474
self.0
475475
.clone()
476-
.check_no_inputs_seen_before(|outpoint| Ok(is_known(&(*outpoint).into())?)),
476+
.check_no_inputs_seen_before(&mut |outpoint| Ok(is_known(&(*outpoint).into())?)),
477477
))))
478478
}
479479
}
@@ -532,7 +532,7 @@ impl OutputsUnknown {
532532
OutputsUnknownTransition(Arc::new(RwLock::new(Some(
533533
self.0
534534
.clone()
535-
.identify_receiver_outputs(|input| Ok(is_receiver_output(&input.to_bytes())?)),
535+
.identify_receiver_outputs(&mut |input| Ok(is_receiver_output(&input.to_bytes())?)),
536536
))))
537537
}
538538
}

payjoin/src/core/receive/multiparty/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ pub struct MaybeInputsOwned {
112112
impl MaybeInputsOwned {
113113
pub fn check_inputs_not_owned(
114114
self,
115-
is_owned: impl Fn(&bitcoin::Script) -> Result<bool, ImplementationError>,
115+
is_owned: &mut impl FnMut(&bitcoin::Script) -> Result<bool, ImplementationError>,
116116
) -> Result<MaybeInputsSeen, Error> {
117117
let inner = self.v1.check_inputs_not_owned(is_owned)?;
118118
Ok(MaybeInputsSeen { v1: inner, contexts: self.contexts })
@@ -127,7 +127,7 @@ pub struct MaybeInputsSeen {
127127
impl MaybeInputsSeen {
128128
pub fn check_no_inputs_seen_before(
129129
self,
130-
is_seen: impl Fn(&bitcoin::OutPoint) -> Result<bool, ImplementationError>,
130+
is_seen: &mut impl FnMut(&bitcoin::OutPoint) -> Result<bool, ImplementationError>,
131131
) -> Result<OutputsUnknown, Error> {
132132
let inner = self.v1.check_no_inputs_seen_before(is_seen)?;
133133
Ok(OutputsUnknown { v1: inner, contexts: self.contexts })
@@ -142,7 +142,7 @@ pub struct OutputsUnknown {
142142
impl OutputsUnknown {
143143
pub fn identify_receiver_outputs(
144144
self,
145-
is_receiver_output: impl Fn(&bitcoin::Script) -> Result<bool, ImplementationError>,
145+
is_receiver_output: &mut impl FnMut(&bitcoin::Script) -> Result<bool, ImplementationError>,
146146
) -> Result<WantsOutputs, Error> {
147147
let inner = self.v1.identify_receiver_outputs(is_receiver_output)?;
148148
Ok(WantsOutputs { v1: inner, contexts: self.contexts })

payjoin/src/core/receive/v1/mod.rs

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ impl UncheckedProposal {
141141
/// Call [`Self::check_inputs_not_owned`] to proceed.
142142
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
143143
pub struct MaybeInputsOwned {
144-
psbt: Psbt,
145-
params: Params,
144+
pub(crate) psbt: Psbt,
145+
pub(crate) params: Params,
146146
}
147147

148148
impl MaybeInputsOwned {
@@ -160,7 +160,7 @@ impl MaybeInputsOwned {
160160
/// An attacker can try to spend the receiver's own inputs. This check prevents that.
161161
pub fn check_inputs_not_owned(
162162
self,
163-
is_owned: impl Fn(&Script) -> Result<bool, ImplementationError>,
163+
is_owned: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
164164
) -> Result<MaybeInputsSeen, ReplyableError> {
165165
let mut err: Result<(), ReplyableError> = Ok(());
166166
if let Some(e) = self
@@ -206,7 +206,7 @@ impl MaybeInputsSeen {
206206
/// original proposal PSBT of the current, new payjoin.
207207
pub fn check_no_inputs_seen_before(
208208
self,
209-
is_known: impl Fn(&OutPoint) -> Result<bool, ImplementationError>,
209+
is_known: &mut impl FnMut(&OutPoint) -> Result<bool, ImplementationError>,
210210
) -> Result<OutputsUnknown, ReplyableError> {
211211
self.psbt.input_pairs().try_for_each(|input| {
212212
match is_known(&input.txin.previous_output) {
@@ -248,7 +248,7 @@ impl OutputsUnknown {
248248
/// outputs.
249249
pub fn identify_receiver_outputs(
250250
self,
251-
is_receiver_output: impl Fn(&Script) -> Result<bool, ImplementationError>,
251+
is_receiver_output: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
252252
) -> Result<WantsOutputs, ReplyableError> {
253253
let owned_vouts: Vec<usize> = self
254254
.psbt
@@ -899,14 +899,21 @@ pub(crate) mod test {
899899
UncheckedProposal { psbt: PARSED_ORIGINAL_PSBT.clone(), params }
900900
}
901901

902+
pub(crate) fn maybe_inputs_owned_from_test_vector() -> MaybeInputsOwned {
903+
let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes());
904+
let params = Params::from_query_pairs(pairs, &[Version::One])
905+
.expect("Could not parse params from query pairs");
906+
MaybeInputsOwned { psbt: PARSED_ORIGINAL_PSBT.clone(), params }
907+
}
908+
902909
fn wants_outputs_from_test_vector(proposal: UncheckedProposal) -> WantsOutputs {
903910
proposal
904911
.assume_interactive_receiver()
905-
.check_inputs_not_owned(|_| Ok(false))
912+
.check_inputs_not_owned(&mut |_| Ok(false))
906913
.expect("No inputs should be owned")
907-
.check_no_inputs_seen_before(|_| Ok(false))
914+
.check_no_inputs_seen_before(&mut |_| Ok(false))
908915
.expect("No inputs should be seen before")
909-
.identify_receiver_outputs(|script| {
916+
.identify_receiver_outputs(&mut |script| {
910917
let network = Network::Bitcoin;
911918
Ok(Address::from_script(script, network).unwrap()
912919
== Address::from_str("3CZZi7aWFugaCdUCS15dgrUUViupmB8bVM")
@@ -925,6 +932,35 @@ pub(crate) mod test {
925932
.expect("Contributed inputs should allow for valid fee contributions")
926933
}
927934

935+
#[test]
936+
fn test_mutable_receiver_state_closures() {
937+
let mut call_count = 0;
938+
let maybe_inputs_owned = maybe_inputs_owned_from_test_vector();
939+
940+
fn mock_callback(call_count: &mut usize, ret: bool) -> Result<bool, ImplementationError> {
941+
*call_count += 1;
942+
Ok(ret)
943+
}
944+
945+
let maybe_inputs_seen = maybe_inputs_owned
946+
.check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false));
947+
assert_eq!(call_count, 1);
948+
949+
let outputs_unknown = maybe_inputs_seen
950+
.map_err(|_| "Check inputs owned closure failed".to_string())
951+
.expect("Next receiver state should be accessible")
952+
.check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false));
953+
assert_eq!(call_count, 2);
954+
955+
let _wants_outputs = outputs_unknown
956+
.map_err(|_| "Check no inputs seen closure failed".to_string())
957+
.expect("Next receiver state should be accessible")
958+
.identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true));
959+
// there are 2 receiver outputs so we should expect this callback to run twice incrementing
960+
// call count twice
961+
assert_eq!(call_count, 4);
962+
}
963+
928964
#[test]
929965
fn is_output_substitution_disabled() {
930966
let mut proposal = unchecked_proposal_from_test_vector();
@@ -975,11 +1011,11 @@ pub(crate) mod test {
9751011
let proposal = unchecked_proposal_from_test_vector();
9761012
let wants_inputs = proposal
9771013
.assume_interactive_receiver()
978-
.check_inputs_not_owned(|_| Ok(false))
1014+
.check_inputs_not_owned(&mut |_| Ok(false))
9791015
.expect("No inputs should be owned")
980-
.check_no_inputs_seen_before(|_| Ok(false))
1016+
.check_no_inputs_seen_before(&mut |_| Ok(false))
9811017
.expect("No inputs should be seen before")
982-
.identify_receiver_outputs(|script| {
1018+
.identify_receiver_outputs(&mut |script| {
9831019
let network = Network::Bitcoin;
9841020
let target_address = Address::from_str("3CZZi7aWFugaCdUCS15dgrUUViupmB8bVM")
9851021
.map_err(ImplementationError::new)?

payjoin/src/core/receive/v2/mod.rs

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ impl Receiver<MaybeInputsOwned> {
568568
/// An attacker can try to spend the receiver's own inputs. This check prevents that.
569569
pub fn check_inputs_not_owned(
570570
self,
571-
is_owned: impl Fn(&Script) -> Result<bool, ImplementationError>,
571+
is_owned: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
572572
) -> MaybeFatalTransition<SessionEvent, Receiver<MaybeInputsSeen>, ReplyableError> {
573573
let inner = match self.state.v1.clone().check_inputs_not_owned(is_owned) {
574574
Ok(inner) => inner,
@@ -618,7 +618,7 @@ impl Receiver<MaybeInputsSeen> {
618618
/// original proposal PSBT of the current, new payjoin.
619619
pub fn check_no_inputs_seen_before(
620620
self,
621-
is_known: impl Fn(&OutPoint) -> Result<bool, ImplementationError>,
621+
is_known: &mut impl FnMut(&OutPoint) -> Result<bool, ImplementationError>,
622622
) -> MaybeFatalTransition<SessionEvent, Receiver<OutputsUnknown>, ReplyableError> {
623623
let inner = match self.state.v1.clone().check_no_inputs_seen_before(is_known) {
624624
Ok(inner) => inner,
@@ -673,7 +673,7 @@ impl Receiver<OutputsUnknown> {
673673
/// outputs.
674674
pub fn identify_receiver_outputs(
675675
self,
676-
is_receiver_output: impl Fn(&Script) -> Result<bool, ImplementationError>,
676+
is_receiver_output: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
677677
) -> MaybeFatalTransition<SessionEvent, Receiver<WantsOutputs>, ReplyableError> {
678678
let inner = match self.state.inner.clone().identify_receiver_outputs(is_receiver_output) {
679679
Ok(inner) => inner,
@@ -1099,6 +1099,50 @@ pub mod test {
10991099
}
11001100
}
11011101

1102+
pub(crate) fn maybe_inputs_owned_v2_from_test_vector() -> MaybeInputsOwned {
1103+
let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes());
1104+
let params = Params::from_query_pairs(pairs, &[Version::Two])
1105+
.expect("Test utils query params should not fail");
1106+
MaybeInputsOwned {
1107+
v1: v1::MaybeInputsOwned { psbt: PARSED_ORIGINAL_PSBT.clone(), params },
1108+
context: SHARED_CONTEXT.clone(),
1109+
}
1110+
}
1111+
1112+
#[test]
1113+
fn test_v2_mutable_receiver_state_closures() {
1114+
let mut call_count = 0;
1115+
let maybe_inputs_owned = maybe_inputs_owned_v2_from_test_vector();
1116+
let receiver = v2::Receiver { state: maybe_inputs_owned };
1117+
1118+
fn mock_callback(call_count: &mut usize, ret: bool) -> Result<bool, ImplementationError> {
1119+
*call_count += 1;
1120+
Ok(ret)
1121+
}
1122+
1123+
let maybe_inputs_seen =
1124+
receiver.check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false));
1125+
assert_eq!(call_count, 1);
1126+
1127+
let outputs_unknown = maybe_inputs_seen
1128+
.0
1129+
.map_err(|_| "Check inputs owned closure failed".to_string())
1130+
.expect("Next receiver state should be accessible")
1131+
.1
1132+
.check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false));
1133+
assert_eq!(call_count, 2);
1134+
1135+
let _wants_outputs = outputs_unknown
1136+
.0
1137+
.map_err(|_| "Check no inputs seen closure failed".to_string())
1138+
.expect("Next receiver state should be accessible")
1139+
.1
1140+
.identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true));
1141+
// there are 2 receiver outputs so we should expect this callback to run twice incrementing
1142+
// call count twice
1143+
assert_eq!(call_count, 4);
1144+
}
1145+
11021146
#[test]
11031147
fn test_unchecked_proposal_transient_error() -> Result<(), BoxError> {
11041148
let unchecked_proposal = unchecked_proposal_v2_from_test_vector();
@@ -1127,7 +1171,7 @@ pub mod test {
11271171
let receiver = v2::Receiver { state: unchecked_proposal };
11281172

11291173
let maybe_inputs_owned = receiver.assume_interactive_receiver();
1130-
let maybe_inputs_seen = maybe_inputs_owned.0 .1.check_inputs_not_owned(|_| {
1174+
let maybe_inputs_seen = maybe_inputs_owned.0 .1.check_inputs_not_owned(&mut |_| {
11311175
Err(ImplementationError::new(ReplyableError::Implementation("mock error".into())))
11321176
});
11331177

@@ -1150,9 +1194,9 @@ pub mod test {
11501194
let receiver = v2::Receiver { state: unchecked_proposal };
11511195

11521196
let maybe_inputs_owned = receiver.assume_interactive_receiver();
1153-
let maybe_inputs_seen = maybe_inputs_owned.0 .1.check_inputs_not_owned(|_| Ok(false));
1197+
let maybe_inputs_seen = maybe_inputs_owned.0 .1.check_inputs_not_owned(&mut |_| Ok(false));
11541198
let outputs_unknown = match maybe_inputs_seen.0 {
1155-
Ok(state) => state.1.check_no_inputs_seen_before(|_| {
1199+
Ok(state) => state.1.check_no_inputs_seen_before(&mut |_| {
11561200
Err(ImplementationError::new(ReplyableError::Implementation("mock error".into())))
11571201
}),
11581202
Err(_) => panic!("Expected Ok, got Err"),
@@ -1177,13 +1221,13 @@ pub mod test {
11771221
let receiver = v2::Receiver { state: unchecked_proposal };
11781222

11791223
let maybe_inputs_owned = receiver.assume_interactive_receiver();
1180-
let maybe_inputs_seen = maybe_inputs_owned.0 .1.check_inputs_not_owned(|_| Ok(false));
1224+
let maybe_inputs_seen = maybe_inputs_owned.0 .1.check_inputs_not_owned(&mut |_| Ok(false));
11811225
let outputs_unknown = match maybe_inputs_seen.0 {
1182-
Ok(state) => state.1.check_no_inputs_seen_before(|_| Ok(false)),
1226+
Ok(state) => state.1.check_no_inputs_seen_before(&mut |_| Ok(false)),
11831227
Err(_) => panic!("Expected Ok, got Err"),
11841228
};
11851229
let wants_outputs = match outputs_unknown.0 {
1186-
Ok(state) => state.1.identify_receiver_outputs(|_| {
1230+
Ok(state) => state.1.identify_receiver_outputs(&mut |_| {
11871231
Err(ImplementationError::new(ReplyableError::Implementation("mock error".into())))
11881232
}),
11891233
Err(_) => panic!("Expected Ok, got Err"),

payjoin/src/core/receive/v2/session.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,15 +178,15 @@ mod tests {
178178
let maybe_inputs_owned = unchecked_proposal.clone().assume_interactive_receiver();
179179
let maybe_inputs_seen = maybe_inputs_owned
180180
.clone()
181-
.check_inputs_not_owned(|_| Ok(false))
181+
.check_inputs_not_owned(&mut |_| Ok(false))
182182
.expect("No inputs should be owned");
183183
let outputs_unknown = maybe_inputs_seen
184184
.clone()
185-
.check_no_inputs_seen_before(|_| Ok(false))
185+
.check_no_inputs_seen_before(&mut |_| Ok(false))
186186
.expect("No inputs should be seen before");
187187
let wants_outputs = outputs_unknown
188188
.clone()
189-
.identify_receiver_outputs(|_| Ok(true))
189+
.identify_receiver_outputs(&mut |_| Ok(true))
190190
.expect("Outputs should be identified");
191191
let wants_inputs = wants_outputs.clone().commit_outputs();
192192
let wants_fee_range = wants_inputs.clone().commit_inputs();
@@ -375,15 +375,15 @@ mod tests {
375375
let maybe_inputs_owned = unchecked_proposal.clone().assume_interactive_receiver();
376376
let maybe_inputs_seen = maybe_inputs_owned
377377
.clone()
378-
.check_inputs_not_owned(|_| Ok(false))
378+
.check_inputs_not_owned(&mut |_| Ok(false))
379379
.expect("No inputs should be owned");
380380
let outputs_unknown = maybe_inputs_seen
381381
.clone()
382-
.check_no_inputs_seen_before(|_| Ok(false))
382+
.check_no_inputs_seen_before(&mut |_| Ok(false))
383383
.expect("No inputs should be seen before");
384384
let wants_outputs = outputs_unknown
385385
.clone()
386-
.identify_receiver_outputs(|_| Ok(true))
386+
.identify_receiver_outputs(&mut |_| Ok(true))
387387
.expect("Outputs should be identified");
388388
let wants_inputs = wants_outputs.clone().commit_outputs();
389389
let wants_fee_range = wants_inputs.clone().commit_inputs();
@@ -425,15 +425,15 @@ mod tests {
425425
let maybe_inputs_owned = unchecked_proposal.clone().assume_interactive_receiver();
426426
let maybe_inputs_seen = maybe_inputs_owned
427427
.clone()
428-
.check_inputs_not_owned(|_| Ok(false))
428+
.check_inputs_not_owned(&mut |_| Ok(false))
429429
.expect("No inputs should be owned");
430430
let outputs_unknown = maybe_inputs_seen
431431
.clone()
432-
.check_no_inputs_seen_before(|_| Ok(false))
432+
.check_no_inputs_seen_before(&mut |_| Ok(false))
433433
.expect("No inputs should be seen before");
434434
let wants_outputs = outputs_unknown
435435
.clone()
436-
.identify_receiver_outputs(|_| Ok(true))
436+
.identify_receiver_outputs(&mut |_| Ok(true))
437437
.expect("Outputs should be identified");
438438
let wants_inputs = wants_outputs.clone().commit_outputs();
439439
let wants_fee_range = wants_inputs.clone().commit_inputs();

0 commit comments

Comments
 (0)