From d5be1f1b439c7648b81c49f7c04f03165ee719ef Mon Sep 17 00:00:00 2001 From: "J.C. Jones" Date: Tue, 4 Aug 2020 20:52:10 -0700 Subject: [PATCH] Add a higher-level AuthenticatorService that can query multiple backends - This moves the callback mechanism into its own file, as it gets more complex - Reworks the C API to use the AuthenticatorService --- Cargo.toml | 1 + examples/main.rs | 111 ++++--- src/authenticatorservice.rs | 618 ++++++++++++++++++++++++++++++++++++ src/capi.rs | 75 +++-- src/freebsd/transaction.rs | 2 +- src/lib.rs | 13 +- src/linux/transaction.rs | 2 +- src/macos/transaction.rs | 2 +- src/manager.rs | 33 +- src/netbsd/transaction.rs | 2 +- src/openbsd/transaction.rs | 2 +- src/statecallback.rs | 162 ++++++++++ src/statemachine.rs | 3 +- src/stub/transaction.rs | 2 +- src/util.rs | 29 -- src/windows/transaction.rs | 2 +- 16 files changed, 921 insertions(+), 138 deletions(-) create mode 100644 src/authenticatorservice.rs create mode 100644 src/statecallback.rs diff --git a/Cargo.toml b/Cargo.toml index c94279d8..f8ab2ea9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,3 +49,4 @@ bitflags = "1.0" sha2 = "^0.8.2" base64 = "^0.10" env_logger = "^0.6" +getopts = "^0.2" diff --git a/examples/main.rs b/examples/main.rs index 3128b811..668d1cc1 100644 --- a/examples/main.rs +++ b/examples/main.rs @@ -2,29 +2,14 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -extern crate authenticator; -extern crate base64; -extern crate sha2; use authenticator::{ - AuthenticatorTransports, KeyHandle, RegisterFlags, SignFlags, StatusUpdate, U2FManager, + authenticatorservice::AuthenticatorService, statecallback::StateCallback, + AuthenticatorTransports, KeyHandle, RegisterFlags, SignFlags, StatusUpdate, }; +use getopts::Options; use sha2::{Digest, Sha256}; use std::sync::mpsc::{channel, RecvError}; -use std::{io, thread}; - -extern crate env_logger; -extern crate log; - -macro_rules! try_or { - ($val:expr, $or:expr) => { - match $val { - Ok(v) => v, - Err(e) => { - return $or(e); - } - } - }; -} +use std::{env, io, thread}; fn u2f_get_key_handle_from_register_response(register_response: &[u8]) -> io::Result> { if register_response[0] != 0x05 { @@ -42,9 +27,37 @@ fn u2f_get_key_handle_from_register_response(register_response: &[u8]) -> io::Re Ok(key_handle) } +fn print_usage(program: &str, opts: Options) { + let brief = format!("Usage: {} [options]", program); + print!("{}", opts.usage(&brief)); +} + fn main() { env_logger::init(); + let args: Vec = env::args().collect(); + let program = args[0].clone(); + + let mut opts = Options::new(); + opts.optflag("x", "no-u2f-usb-hid", "do not enable u2f-usb-hid platforms"); + + opts.optflag("h", "help", "print this help menu"); + let matches = match opts.parse(&args[1..]) { + Ok(m) => m, + Err(f) => panic!(f.to_string()), + }; + if matches.opt_present("help") { + print_usage(&program, opts); + return; + } + + let mut manager = + AuthenticatorService::new().expect("The auth service should initialize safely"); + + if !matches.opt_present("no-u2f-usb-hid") { + manager.add_u2f_usb_hid_platform_transports(); + } + println!("Asking a security key to register now..."); let challenge_str = format!( "{}{}", @@ -59,7 +72,6 @@ fn main() { application.input(b"http://demo.yubico.com"); let app_bytes = application.result().to_vec(); - let manager = U2FManager::new().unwrap(); let flags = RegisterFlags::empty(); let (status_tx, status_rx) = channel::(); @@ -82,25 +94,26 @@ fn main() { }); let (register_tx, register_rx) = channel(); + let callback = StateCallback::new(Box::new(move |rv| { + register_tx.send(rv).unwrap(); + })); + manager .register( flags, - 15_000, + 60_000 * 5, chall_bytes.clone(), app_bytes.clone(), vec![], status_tx.clone(), - move |rv| { - register_tx.send(rv).unwrap(); - }, + callback, ) - .unwrap(); + .expect("Couldn't register"); - let register_result = try_or!(register_rx.recv(), |_| { - panic!("Problem receiving, unable to continue"); - }); - let (register_data, device_info) = - register_result.unwrap_or_else(|e| panic!("Registration failed: {:?}", e)); + let register_result = register_rx + .recv() + .expect("Problem receiving, unable to continue"); + let (register_data, device_info) = register_result.expect("Registration failed"); println!("Register result: {}", base64::encode(®ister_data)); println!("Device info: {}", &device_info); @@ -113,25 +126,27 @@ fn main() { let flags = SignFlags::empty(); let (sign_tx, sign_rx) = channel(); - manager - .sign( - flags, - 15_000, - chall_bytes, - vec![app_bytes], - vec![key_handle], - status_tx, - move |rv| { - sign_tx.send(rv).unwrap(); - }, - ) - .unwrap(); - let sign_result = try_or!(sign_rx.recv(), |_| { - panic!("Problem receiving, unable to continue"); - }); - let (_, handle_used, sign_data, device_info) = - sign_result.unwrap_or_else(|e| panic!("Sign failed: {:?}", e)); + let callback = StateCallback::new(Box::new(move |rv| { + sign_tx.send(rv).unwrap(); + })); + + if let Err(e) = manager.sign( + flags, + 15_000, + chall_bytes, + vec![app_bytes], + vec![key_handle], + status_tx, + callback, + ) { + panic!("Couldn't register: {:?}", e); + } + + let sign_result = sign_rx + .recv() + .expect("Problem receiving, unable to continue"); + let (_, handle_used, sign_data, device_info) = sign_result.expect("Sign failed"); println!("Sign result: {}", base64::encode(&sign_data)); println!("Key handle used: {}", base64::encode(&handle_used)); diff --git a/src/authenticatorservice.rs b/src/authenticatorservice.rs new file mode 100644 index 00000000..0a4f2686 --- /dev/null +++ b/src/authenticatorservice.rs @@ -0,0 +1,618 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use std::sync::{mpsc::Sender, Arc, Mutex}; + +use crate::consts::PARAMETER_SIZE; +use crate::statecallback::StateCallback; + +pub trait AuthenticatorTransport { + /// The implementation of this method must return quickly and should + /// report its status via the status and callback methods + fn register( + &mut self, + flags: crate::RegisterFlags, + timeout: u64, + challenge: Vec, + application: crate::AppId, + key_handles: Vec, + status: Sender, + callback: StateCallback>, + ) -> Result<(), crate::Error>; + + /// The implementation of this method must return quickly and should + /// report its status via the status and callback methods + fn sign( + &mut self, + flags: crate::SignFlags, + timeout: u64, + challenge: Vec, + app_ids: Vec, + key_handles: Vec, + status: Sender, + callback: StateCallback>, + ) -> Result<(), crate::Error>; + + fn cancel(&mut self) -> Result<(), crate::Error>; +} + +pub struct AuthenticatorService { + transports: Vec>>>, +} + +fn clone_and_configure_cancellation_callback( + mut callback: StateCallback, + transports_to_cancel: Vec>>>, +) -> StateCallback { + callback.add_uncloneable_observer(Box::new(move || { + debug!( + "Callback observer is running, cancelling \ + {} unchosen transports...", + transports_to_cancel.len() + ); + for transport_mutex in &transports_to_cancel { + if let Err(e) = transport_mutex.lock().unwrap().cancel() { + error!("Cancellation failed: {:?}", e); + } + } + })); + callback +} + +impl AuthenticatorService { + pub fn new() -> Result { + Ok(Self { + transports: Vec::new(), + }) + } + + /// Add any detected platform transports + pub fn add_detected_transports(&mut self) { + self.add_u2f_usb_hid_platform_transports(); + } + + fn add_transport(&mut self, boxed_token: Box) { + self.transports.push(Arc::new(Mutex::new(boxed_token))) + } + + pub fn add_u2f_usb_hid_platform_transports(&mut self) { + match crate::U2FManager::new() { + Ok(token) => self.add_transport(Box::new(token)), + Err(e) => error!("Could not add U2F HID transport: {}", e), + } + } + + pub fn register( + &mut self, + flags: crate::RegisterFlags, + timeout: u64, + challenge: Vec, + application: crate::AppId, + key_handles: Vec, + status: Sender, + callback: StateCallback>, + ) -> Result<(), crate::Error> { + if challenge.len() != PARAMETER_SIZE || application.len() != PARAMETER_SIZE { + return Err(crate::Error::Unknown); + } + + for key_handle in &key_handles { + if key_handle.credential.len() > 256 { + return Err(crate::Error::Unknown); + } + } + + let iterable_transports = self.transports.clone(); + if iterable_transports.is_empty() { + return Err(crate::Error::NotSupported); + } + + debug!( + "register called with {} transports, iterable is {}", + self.transports.len(), + iterable_transports.len() + ); + + for (idx, transport_mutex) in iterable_transports.iter().enumerate() { + let mut transports_to_cancel = iterable_transports.clone(); + transports_to_cancel.remove(idx); + + debug!( + "register transports_to_cancel {}", + transports_to_cancel.len() + ); + + transport_mutex.lock().unwrap().register( + flags.clone(), + timeout, + challenge.clone(), + application.clone(), + key_handles.clone(), + status.clone(), + clone_and_configure_cancellation_callback(callback.clone(), transports_to_cancel), + )?; + } + + Ok(()) + } + + pub fn sign( + &mut self, + flags: crate::SignFlags, + timeout: u64, + challenge: Vec, + app_ids: Vec, + key_handles: Vec, + status: Sender, + callback: StateCallback>, + ) -> Result<(), crate::Error> { + if challenge.len() != PARAMETER_SIZE { + return Err(crate::Error::Unknown); + } + + if app_ids.is_empty() { + return Err(crate::Error::Unknown); + } + + for app_id in &app_ids { + if app_id.len() != PARAMETER_SIZE { + return Err(crate::Error::Unknown); + } + } + + for key_handle in &key_handles { + if key_handle.credential.len() > 256 { + return Err(crate::Error::Unknown); + } + } + + let iterable_transports = self.transports.clone(); + if iterable_transports.is_empty() { + return Err(crate::Error::NotSupported); + } + + for (idx, transport_mutex) in iterable_transports.iter().enumerate() { + let mut transports_to_cancel = iterable_transports.clone(); + transports_to_cancel.remove(idx); + + transport_mutex.lock().unwrap().sign( + flags.clone(), + timeout, + challenge.clone(), + app_ids.clone(), + key_handles.clone(), + status.clone(), + clone_and_configure_cancellation_callback(callback.clone(), transports_to_cancel), + )?; + } + + Ok(()) + } + + pub fn cancel(&mut self) -> Result<(), crate::Error> { + if self.transports.is_empty() { + return Err(crate::Error::NotSupported); + } + + for transport_mutex in &mut self.transports { + transport_mutex.lock().unwrap().cancel()?; + } + + Ok(()) + } +} + +//////////////////////////////////////////////////////////////////////// +// Tests +//////////////////////////////////////////////////////////////////////// + +#[cfg(test)] +mod tests { + use super::{AuthenticatorService, AuthenticatorTransport}; + use crate::consts::PARAMETER_SIZE; + use crate::statecallback::StateCallback; + use crate::{AuthenticatorTransports, KeyHandle, RegisterFlags, SignFlags, StatusUpdate}; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::{io, thread}; + + fn init() { + let _ = env_logger::builder().is_test(true).try_init(); + } + + pub struct TestTransportDriver { + consent: bool, + was_cancelled: Arc, + } + + impl TestTransportDriver { + pub fn new(consent: bool) -> io::Result { + Ok(Self { + consent, + was_cancelled: Arc::new(AtomicBool::new(false)), + }) + } + } + + impl TestTransportDriver { + fn dev_info(&self) -> crate::u2ftypes::U2FDeviceInfo { + crate::u2ftypes::U2FDeviceInfo { + vendor_name: String::from("Mozilla").into_bytes(), + device_name: String::from("Test Transport Token").into_bytes(), + version_interface: 0, + version_major: 1, + version_minor: 2, + version_build: 3, + cap_flags: 0, + } + } + } + + impl AuthenticatorTransport for TestTransportDriver { + fn register( + &mut self, + _flags: crate::RegisterFlags, + _timeout: u64, + _challenge: Vec, + _application: crate::AppId, + _key_handles: Vec, + _status: Sender, + callback: StateCallback>, + ) -> Result<(), crate::Error> { + if self.consent { + let rv = Ok((vec![0u8; 16], self.dev_info())); + thread::spawn(move || callback.call(rv)); + } + Ok(()) + } + + fn sign( + &mut self, + _flags: crate::SignFlags, + _timeout: u64, + _challenge: Vec, + _app_ids: Vec, + _key_handles: Vec, + _status: Sender, + callback: StateCallback>, + ) -> Result<(), crate::Error> { + if self.consent { + let rv = Ok((vec![0u8; 0], vec![0u8; 0], vec![0u8; 0], self.dev_info())); + thread::spawn(move || callback.call(rv)); + } + Ok(()) + } + + fn cancel(&mut self) -> Result<(), crate::Error> { + self.was_cancelled + .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst) + .map_or(Err(crate::Error::InvalidState), |_| Ok(())) + } + } + + fn mk_key() -> KeyHandle { + KeyHandle { + credential: vec![0], + transports: AuthenticatorTransports::USB, + } + } + + fn mk_challenge() -> Vec { + vec![0x11; PARAMETER_SIZE] + } + + fn mk_appid() -> Vec { + vec![0x22; PARAMETER_SIZE] + } + + #[test] + fn test_no_challenge() { + init(); + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + s.add_transport(Box::new(TestTransportDriver::new(true).unwrap())); + + assert_eq!( + s.register( + RegisterFlags::empty(), + 1_000, + vec![], + mk_appid(), + vec![mk_key()], + status_tx.clone(), + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::Unknown + ); + + assert_eq!( + s.sign( + SignFlags::empty(), + 1_000, + vec![], + vec![mk_appid()], + vec![mk_key()], + status_tx, + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::Unknown + ); + } + + #[test] + fn test_no_appids() { + init(); + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + s.add_transport(Box::new(TestTransportDriver::new(true).unwrap())); + + assert_eq!( + s.register( + RegisterFlags::empty(), + 1_000, + mk_challenge(), + vec![], + vec![mk_key()], + status_tx.clone(), + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::Unknown + ); + + assert_eq!( + s.sign( + SignFlags::empty(), + 1_000, + mk_challenge(), + vec![], + vec![mk_key()], + status_tx, + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::Unknown + ); + } + + #[test] + fn test_no_keys() { + init(); + // No Keys is a resident-key use case. For U2F this would time out, + // but the actual reactions are up to the service implementation. + // This test yields OKs. + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + s.add_transport(Box::new(TestTransportDriver::new(true).unwrap())); + + assert_eq!( + s.register( + RegisterFlags::empty(), + 100, + mk_challenge(), + mk_appid(), + vec![], + status_tx.clone(), + StateCallback::new(Box::new(move |_rv| {})), + ), + Ok(()) + ); + + assert_eq!( + s.sign( + SignFlags::empty(), + 100, + mk_challenge(), + vec![mk_appid()], + vec![], + status_tx, + StateCallback::new(Box::new(move |_rv| {})), + ), + Ok(()) + ); + } + + #[test] + fn test_large_keys() { + init(); + let (status_tx, _) = channel::(); + + let large_key = KeyHandle { + credential: vec![0; 257], + transports: AuthenticatorTransports::USB, + }; + + let mut s = AuthenticatorService::new().unwrap(); + s.add_transport(Box::new(TestTransportDriver::new(true).unwrap())); + + assert_eq!( + s.register( + RegisterFlags::empty(), + 1_000, + mk_challenge(), + mk_appid(), + vec![large_key.clone()], + status_tx.clone(), + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::Unknown + ); + + assert_eq!( + s.sign( + SignFlags::empty(), + 1_000, + mk_challenge(), + vec![mk_appid()], + vec![large_key], + status_tx, + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::Unknown + ); + } + + #[test] + fn test_no_transports() { + init(); + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + assert_eq!( + s.register( + RegisterFlags::empty(), + 1_000, + mk_challenge(), + mk_appid(), + vec![mk_key()], + status_tx.clone(), + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::NotSupported + ); + + assert_eq!( + s.sign( + SignFlags::empty(), + 1_000, + mk_challenge(), + vec![mk_appid()], + vec![mk_key()], + status_tx, + StateCallback::new(Box::new(move |_rv| {})), + ) + .unwrap_err(), + crate::Error::NotSupported + ); + + assert_eq!(s.cancel().unwrap_err(), crate::Error::NotSupported); + } + + #[test] + fn test_cancellation_register() { + init(); + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + let ttd_one = TestTransportDriver::new(true).unwrap(); + let ttd_two = TestTransportDriver::new(false).unwrap(); + let ttd_three = TestTransportDriver::new(false).unwrap(); + + let was_cancelled_one = ttd_one.was_cancelled.clone(); + let was_cancelled_two = ttd_two.was_cancelled.clone(); + let was_cancelled_three = ttd_three.was_cancelled.clone(); + + s.add_transport(Box::new(ttd_one)); + s.add_transport(Box::new(ttd_two)); + s.add_transport(Box::new(ttd_three)); + + let callback = StateCallback::new(Box::new(move |_rv| {})); + assert_eq!( + s.register( + RegisterFlags::empty(), + 1_000, + mk_challenge(), + mk_appid(), + vec![], + status_tx.clone(), + callback.clone(), + ), + Ok(()) + ); + callback.wait(); + + assert_eq!(was_cancelled_one.load(Ordering::SeqCst), false); + assert_eq!(was_cancelled_two.load(Ordering::SeqCst), true); + assert_eq!(was_cancelled_three.load(Ordering::SeqCst), true); + } + + #[test] + fn test_cancellation_sign() { + init(); + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + let ttd_one = TestTransportDriver::new(true).unwrap(); + let ttd_two = TestTransportDriver::new(false).unwrap(); + let ttd_three = TestTransportDriver::new(false).unwrap(); + + let was_cancelled_one = ttd_one.was_cancelled.clone(); + let was_cancelled_two = ttd_two.was_cancelled.clone(); + let was_cancelled_three = ttd_three.was_cancelled.clone(); + + s.add_transport(Box::new(ttd_one)); + s.add_transport(Box::new(ttd_two)); + s.add_transport(Box::new(ttd_three)); + + let callback = StateCallback::new(Box::new(move |_rv| {})); + assert_eq!( + s.sign( + SignFlags::empty(), + 1_000, + mk_challenge(), + vec![mk_appid()], + vec![mk_key()], + status_tx, + callback.clone(), + ), + Ok(()) + ); + callback.wait(); + + assert_eq!(was_cancelled_one.load(Ordering::SeqCst), false); + assert_eq!(was_cancelled_two.load(Ordering::SeqCst), true); + assert_eq!(was_cancelled_three.load(Ordering::SeqCst), true); + } + + #[test] + fn test_cancellation_race() { + init(); + let (status_tx, _) = channel::(); + + let mut s = AuthenticatorService::new().unwrap(); + // Let both of these race which one provides consent. + let ttd_one = TestTransportDriver::new(true).unwrap(); + let ttd_two = TestTransportDriver::new(true).unwrap(); + + let was_cancelled_one = ttd_one.was_cancelled.clone(); + let was_cancelled_two = ttd_two.was_cancelled.clone(); + + s.add_transport(Box::new(ttd_one)); + s.add_transport(Box::new(ttd_two)); + + let callback = StateCallback::new(Box::new(move |_rv| {})); + assert_eq!( + s.register( + RegisterFlags::empty(), + 1_000, + mk_challenge(), + mk_appid(), + vec![], + status_tx.clone(), + callback.clone(), + ), + Ok(()) + ); + callback.wait(); + + let one = was_cancelled_one.load(Ordering::SeqCst); + let two = was_cancelled_two.load(Ordering::SeqCst); + assert_eq!( + one ^ two, + true, + "asserting that one={} xor two={} is true", + one, + two + ); + } +} diff --git a/src/capi.rs b/src/capi.rs index c239ad44..3a29a9ed 100644 --- a/src/capi.rs +++ b/src/capi.rs @@ -2,14 +2,15 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +use crate::authenticatorservice::AuthenticatorService; +use crate::statecallback::StateCallback; +use crate::{RegisterResult, SignResult}; use libc::size_t; use rand::{thread_rng, Rng}; use std::collections::HashMap; use std::sync::mpsc::channel; use std::{ptr, slice}; -use crate::U2FManager; - type U2FAppIds = Vec; type U2FKeyHandles = Vec; type U2FCallback = extern "C" fn(u64, *mut U2FResult); @@ -42,8 +43,9 @@ unsafe fn from_raw(ptr: *const u8, len: usize) -> Vec { /// /// The handle returned by this method must be freed by the caller. #[no_mangle] -pub extern "C" fn rust_u2f_mgr_new() -> *mut U2FManager { - if let Ok(mgr) = U2FManager::new() { +pub extern "C" fn rust_u2f_mgr_new() -> *mut AuthenticatorService { + if let Ok(mut mgr) = AuthenticatorService::new() { + mgr.add_detected_transports(); Box::into_raw(Box::new(mgr)) } else { ptr::null_mut() @@ -55,7 +57,7 @@ pub extern "C" fn rust_u2f_mgr_new() -> *mut U2FManager { /// This method must not be called on a handle twice, and the handle is unusable /// after. #[no_mangle] -pub unsafe extern "C" fn rust_u2f_mgr_free(mgr: *mut U2FManager) { +pub unsafe extern "C" fn rust_u2f_mgr_free(mgr: *mut AuthenticatorService) { if !mgr.is_null() { Box::from_raw(mgr); } @@ -192,8 +194,8 @@ pub unsafe extern "C" fn rust_u2f_resbuf_copy( /// # Safety /// -/// This method should not be called U2FManager handles after they've been freed -/// or a double-free will occur +/// This method should not be called on U2FResult handles after they've been +/// freed or a double-free will occur #[no_mangle] pub unsafe extern "C" fn rust_u2f_res_free(res: *mut U2FResult) { if !res.is_null() { @@ -203,10 +205,11 @@ pub unsafe extern "C" fn rust_u2f_res_free(res: *mut U2FResult) { /// # Safety /// -/// This method should not be called U2FManager handles after they've been freed +/// This method should not be called on AuthenticatorService handles after +/// they've been freed #[no_mangle] pub unsafe extern "C" fn rust_u2f_mgr_register( - mgr: *mut U2FManager, + mgr: *mut AuthenticatorService, flags: u64, timeout: u64, callback: U2FCallback, @@ -231,16 +234,10 @@ pub unsafe extern "C" fn rust_u2f_mgr_register( let key_handles = (*khs).clone(); let (tx, _rx) = channel::(); - let tid = new_tid(); - let res = (*mgr).register( - flags, - timeout, - challenge, - application, - key_handles, - tx, - move |rv| { + + let state_callback = + StateCallback::>::new(Box::new(move |rv| { let result = match rv { Ok((registration, dev_info)) => { let mut bufs = HashMap::new(); @@ -256,7 +253,16 @@ pub unsafe extern "C" fn rust_u2f_mgr_register( }; callback(tid, Box::into_raw(Box::new(result))); - }, + })); + + let res = (*mgr).register( + flags, + timeout, + challenge, + application, + key_handles, + tx, + state_callback, ); if res.is_ok() { @@ -268,10 +274,11 @@ pub unsafe extern "C" fn rust_u2f_mgr_register( /// # Safety /// -/// This method should not be called U2FManager handles after they've been freed +/// This method should not be called on AuthenticatorService handles after +/// they've been freed #[no_mangle] pub unsafe extern "C" fn rust_u2f_mgr_sign( - mgr: *mut U2FManager, + mgr: *mut AuthenticatorService, flags: u64, timeout: u64, callback: U2FCallback, @@ -302,14 +309,8 @@ pub unsafe extern "C" fn rust_u2f_mgr_sign( let (tx, _rx) = channel::(); let tid = new_tid(); - let res = (*mgr).sign( - flags, - timeout, - challenge, - app_ids, - key_handles, - tx, - move |rv| { + let state_callback = + StateCallback::>::new(Box::new(move |rv| { let result = match rv { Ok((app_id, key_handle, signature, dev_info)) => { let mut bufs = HashMap::new(); @@ -327,7 +328,16 @@ pub unsafe extern "C" fn rust_u2f_mgr_sign( }; callback(tid, Box::into_raw(Box::new(result))); - }, + })); + + let res = (*mgr).sign( + flags, + timeout, + challenge, + app_ids, + key_handles, + tx, + state_callback, ); if res.is_ok() { @@ -339,9 +349,10 @@ pub unsafe extern "C" fn rust_u2f_mgr_sign( /// # Safety /// -/// This method should not be called U2FManager handles after they've been freed +/// This method should not be called AuthenticatorService handles after they've +/// been freed #[no_mangle] -pub unsafe extern "C" fn rust_u2f_mgr_cancel(mgr: *mut U2FManager) { +pub unsafe extern "C" fn rust_u2f_mgr_cancel(mgr: *mut AuthenticatorService) { if !mgr.is_null() { // Ignore return value. let _ = (*mgr).cancel(); diff --git a/src/freebsd/transaction.rs b/src/freebsd/transaction.rs index cfdc54b9..26858510 100644 --- a/src/freebsd/transaction.rs +++ b/src/freebsd/transaction.rs @@ -3,7 +3,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::platform::monitor::Monitor; -use crate::util::StateCallback; +use crate::statecallback::StateCallback; use runloop::RunLoop; use std::ffi::OsString; diff --git a/src/lib.rs b/src/lib.rs index 1afaee3e..ca0e0710 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,6 +61,7 @@ extern crate runloop; #[macro_use] extern crate bitflags; +pub mod authenticatorservice; mod consts; mod statemachine; mod u2fprotocol; @@ -72,6 +73,8 @@ pub use crate::manager::U2FManager; mod capi; pub use crate::capi::*; +pub mod statecallback; + // Keep this in sync with the constants in u2fhid-capi.h. bitflags! { pub struct RegisterFlags: u64 { @@ -103,7 +106,7 @@ pub type AppId = Vec; pub type RegisterResult = (Vec, u2ftypes::U2FDeviceInfo); pub type SignResult = (AppId, Vec, Vec, u2ftypes::U2FDeviceInfo); -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Error { Unknown = 1, NotSupported = 2, @@ -112,6 +115,14 @@ pub enum Error { NotAllowed = 5, } +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl std::error::Error for Error {} + #[derive(Debug, Clone)] pub enum StatusUpdate { DeviceAvailable { dev_info: u2ftypes::U2FDeviceInfo }, diff --git a/src/linux/transaction.rs b/src/linux/transaction.rs index cfdc54b9..26858510 100644 --- a/src/linux/transaction.rs +++ b/src/linux/transaction.rs @@ -3,7 +3,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::platform::monitor::Monitor; -use crate::util::StateCallback; +use crate::statecallback::StateCallback; use runloop::RunLoop; use std::ffi::OsString; diff --git a/src/macos/transaction.rs b/src/macos/transaction.rs index 7b43a558..f523fc43 100644 --- a/src/macos/transaction.rs +++ b/src/macos/transaction.rs @@ -6,7 +6,7 @@ extern crate libc; use crate::platform::iokit::{CFRunLoopEntryObserver, IOHIDDeviceRef, SendableRunLoop}; use crate::platform::monitor::Monitor; -use crate::util::StateCallback; +use crate::statecallback::StateCallback; use core_foundation::runloop::*; use std::os::raw::c_void; use std::sync::mpsc::{channel, Receiver, Sender}; diff --git a/src/manager.rs b/src/manager.rs index 76acc9d6..06c2351c 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -6,9 +6,10 @@ use std::io; use std::sync::mpsc::{channel, RecvTimeoutError, Sender}; use std::time::Duration; +use crate::authenticatorservice::AuthenticatorTransport; use crate::consts::PARAMETER_SIZE; +use crate::statecallback::StateCallback; use crate::statemachine::StateMachine; -use crate::util::StateCallback; use runloop::RunLoop; enum QueueAction { @@ -106,21 +107,19 @@ impl U2FManager { Ok(Self { queue, tx }) } +} - pub fn register( - &self, +impl AuthenticatorTransport for U2FManager { + fn register( + &mut self, flags: crate::RegisterFlags, timeout: u64, challenge: Vec, application: crate::AppId, key_handles: Vec, status: Sender, - callback: F, - ) -> Result<(), crate::Error> - where - F: Fn(Result), - F: Send + 'static, - { + callback: StateCallback>, + ) -> Result<(), crate::Error> { if challenge.len() != PARAMETER_SIZE || application.len() != PARAMETER_SIZE { return Err(crate::Error::Unknown); } @@ -131,7 +130,6 @@ impl U2FManager { } } - let callback = StateCallback::new(Box::new(callback)); let action = QueueAction::Register { flags, timeout, @@ -144,20 +142,16 @@ impl U2FManager { self.tx.send(action).map_err(|_| crate::Error::Unknown) } - pub fn sign( - &self, + fn sign( + &mut self, flags: crate::SignFlags, timeout: u64, challenge: Vec, app_ids: Vec, key_handles: Vec, status: Sender, - callback: F, - ) -> Result<(), crate::Error> - where - F: Fn(Result), - F: Send + 'static, - { + callback: StateCallback>, + ) -> Result<(), crate::Error> { if challenge.len() != PARAMETER_SIZE { return Err(crate::Error::Unknown); } @@ -178,7 +172,6 @@ impl U2FManager { } } - let callback = StateCallback::new(Box::new(callback)); let action = QueueAction::Sign { flags, timeout, @@ -191,7 +184,7 @@ impl U2FManager { self.tx.send(action).map_err(|_| crate::Error::Unknown) } - pub fn cancel(&self) -> Result<(), crate::Error> { + fn cancel(&mut self) -> Result<(), crate::Error> { self.tx .send(QueueAction::Cancel) .map_err(|_| crate::Error::Unknown) diff --git a/src/netbsd/transaction.rs b/src/netbsd/transaction.rs index c3710219..648cc8d7 100644 --- a/src/netbsd/transaction.rs +++ b/src/netbsd/transaction.rs @@ -2,7 +2,7 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use crate::util::StateCallback; +use crate::statecallback::StateCallback; use runloop::RunLoop; use crate::platform::fd::Fd; diff --git a/src/openbsd/transaction.rs b/src/openbsd/transaction.rs index 8fd7618d..350a1283 100644 --- a/src/openbsd/transaction.rs +++ b/src/openbsd/transaction.rs @@ -3,7 +3,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::platform::monitor::{FidoDev, Monitor}; -use crate::util::StateCallback; +use crate::statecallback::StateCallback; use runloop::RunLoop; pub struct Transaction { diff --git a/src/statecallback.rs b/src/statecallback.rs new file mode 100644 index 00000000..14797899 --- /dev/null +++ b/src/statecallback.rs @@ -0,0 +1,162 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use std::sync::{Arc, Condvar, Mutex}; + +pub struct StateCallback { + callback: Arc>>>, + observer: Arc>>>, + condition: Arc<(Mutex, Condvar)>, +} + +impl StateCallback { + pub fn new(cb: Box) -> Self { + Self { + callback: Arc::new(Mutex::new(Some(cb))), + observer: Arc::new(Mutex::new(None)), + condition: Arc::new((Mutex::new(true), Condvar::new())), + } + } + + pub fn add_uncloneable_observer(&mut self, obs: Box) { + let mut opt = self.observer.lock().unwrap(); + if opt.is_some() { + error!("Replacing an already-set observer.") + } + opt.replace(obs); + } + + pub fn call(&self, rv: T) { + if let Some(cb) = self.callback.lock().unwrap().take() { + cb(rv); + + if let Some(obs) = self.observer.lock().unwrap().take() { + obs(); + } + } + + let (lock, cvar) = &*self.condition; + let mut pending = lock.lock().unwrap(); + *pending = false; + cvar.notify_all(); + } + + pub fn wait(&self) { + let (lock, cvar) = &*self.condition; + let _useless_guard = cvar + .wait_while(lock.lock().unwrap(), |pending| *pending) + .unwrap(); + } +} + +impl Clone for StateCallback { + fn clone(&self) -> Self { + Self { + callback: self.callback.clone(), + observer: Arc::new(Mutex::new(None)), + condition: self.condition.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::StateCallback; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::{Arc, Barrier}; + use std::thread; + + #[test] + fn test_statecallback_is_single_use() { + let counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = counter.clone(); + let sc = StateCallback::new(Box::new(move |_| { + counter_clone.fetch_add(1, Ordering::SeqCst); + })); + + assert_eq!(counter.load(Ordering::SeqCst), 0); + for _ in 0..10 { + sc.call(()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + for _ in 0..10 { + sc.clone().call(()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + } + + #[test] + fn test_statecallback_observer_is_single_use() { + let counter = Arc::new(AtomicUsize::new(0)); + let counter_clone = counter.clone(); + let mut sc = StateCallback::<()>::new(Box::new(move |_| {})); + + sc.add_uncloneable_observer(Box::new(move || { + counter_clone.fetch_add(1, Ordering::SeqCst); + })); + + assert_eq!(counter.load(Ordering::SeqCst), 0); + for _ in 0..10 { + sc.call(()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + + for _ in 0..10 { + sc.clone().call(()); + assert_eq!(counter.load(Ordering::SeqCst), 1); + } + } + + #[test] + fn test_statecallback_observer_only_runs_for_completing_callback() { + let cb_counter = Arc::new(AtomicUsize::new(0)); + let cb_counter_clone = cb_counter.clone(); + let sc = StateCallback::new(Box::new(move |_| { + cb_counter_clone.fetch_add(1, Ordering::SeqCst); + })); + + let obs_counter = Arc::new(AtomicUsize::new(0)); + + for _ in 0..10 { + let obs_counter_clone = obs_counter.clone(); + let mut c = sc.clone(); + c.add_uncloneable_observer(Box::new(move || { + obs_counter_clone.fetch_add(1, Ordering::SeqCst); + })); + + c.call(()); + + assert_eq!(cb_counter.load(Ordering::SeqCst), 1); + assert_eq!(obs_counter.load(Ordering::SeqCst), 1); + } + } + + #[test] + fn test_statecallback_observer_unclonable() { + let mut sc = StateCallback::<()>::new(Box::new(move |_| {})); + sc.add_uncloneable_observer(Box::new(move || {})); + + assert!(sc.observer.lock().unwrap().is_some()); + assert!(sc.clone().observer.lock().unwrap().is_none()); + } + + #[test] + fn test_statecallback_wait() { + let sc = StateCallback::<()>::new(Box::new(move |_| {})); + let barrier = Arc::new(Barrier::new(2)); + + { + let c = sc.clone(); + let b = barrier.clone(); + thread::spawn(move || { + b.wait(); + c.call(()); + }); + } + + barrier.wait(); + sc.wait(); + } +} diff --git a/src/statemachine.rs b/src/statemachine.rs index b8d62e4a..a8118ed7 100644 --- a/src/statemachine.rs +++ b/src/statemachine.rs @@ -5,9 +5,10 @@ use crate::consts::PARAMETER_SIZE; use crate::platform::device::Device; use crate::platform::transaction::Transaction; +use crate::statecallback::StateCallback; use crate::u2fprotocol::{u2f_init_device, u2f_is_keyhandle_valid, u2f_register, u2f_sign}; use crate::u2ftypes::U2FDevice; -use crate::util::StateCallback; + use std::sync::mpsc::Sender; use std::sync::Mutex; use std::thread; diff --git a/src/stub/transaction.rs b/src/stub/transaction.rs index b697d54b..5febbe4a 100644 --- a/src/stub/transaction.rs +++ b/src/stub/transaction.rs @@ -2,7 +2,7 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ -use crate::util::StateCallback; +use crate::statecallback::StateCallback; pub struct Transaction {} diff --git a/src/util.rs b/src/util.rs index 80bdfc6a..4ccb3c97 100644 --- a/src/util.rs +++ b/src/util.rs @@ -5,7 +5,6 @@ extern crate libc; use std::io; -use std::sync::{Arc, Mutex}; macro_rules! try_or { ($val:expr, $or:expr) => { @@ -66,31 +65,3 @@ pub fn from_unix_result(rv: T) -> io::Result { pub fn io_err(msg: &str) -> io::Error { io::Error::new(io::ErrorKind::Other, msg) } - -pub struct StateCallback { - callback: Arc>>>, -} - -impl StateCallback { - pub fn new(cb: Box) -> Self { - Self { - callback: Arc::new(Mutex::new(Some(cb))), - } - } - - pub fn call(&self, rv: T) { - if let Ok(mut cb) = self.callback.lock() { - if let Some(cb) = cb.take() { - cb(rv); - } - } - } -} - -impl Clone for StateCallback { - fn clone(&self) -> Self { - Self { - callback: self.callback.clone(), - } - } -} diff --git a/src/windows/transaction.rs b/src/windows/transaction.rs index 96c1996a..de738134 100644 --- a/src/windows/transaction.rs +++ b/src/windows/transaction.rs @@ -3,7 +3,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use crate::platform::monitor::Monitor; -use crate::util::StateCallback; +use crate::statecallback::StateCallback; use runloop::RunLoop; pub struct Transaction {