Skip to content

Commit f92aee8

Browse files
authored
feat: add additional application context into Connection (#5637)
1 parent a23eb49 commit f92aee8

File tree

2 files changed

+81
-28
lines changed

2 files changed

+81
-28
lines changed

bindings/rust/extended/s2n-tls/src/connection.rs

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ use core::{
2828
};
2929
use libc::c_void;
3030
use s2n_tls_sys::*;
31-
use std::{any::Any, ffi::CStr};
31+
use std::{
32+
any::{Any, TypeId},
33+
collections::HashMap,
34+
ffi::CStr,
35+
};
3236

3337
mod builder;
3438
pub use builder::*;
@@ -1408,49 +1412,63 @@ impl Connection {
14081412
Ok(())
14091413
}
14101414

1411-
/// Associates an arbitrary application context with the Connection to be later retrieved via
1415+
/// Associates arbitrary application contexts with the Connection to be later retrieved via
14121416
/// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs.
14131417
///
1414-
/// This API will override an existing application context set on the Connection.
1418+
/// While multiple application contexts of different types may be set, previous values of the same type will be overridden.
14151419
///
14161420
/// Corresponds to [s2n_connection_set_ctx].
14171421
pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
1418-
self.context_mut().app_context = Some(Box::new(app_context));
1422+
let context_type_id = TypeId::of::<T>();
1423+
self.context_mut()
1424+
.app_context
1425+
.insert(context_type_id, Box::new(app_context));
1426+
}
1427+
1428+
/// Removes an application context set on the Connection.
1429+
///
1430+
/// Returns Some containing the removed context if it exists, or None if no context
1431+
/// of the specified type was previously set.
1432+
pub fn remove_application_context<T: Send + Sync + 'static>(
1433+
&mut self,
1434+
) -> Option<Box<dyn Any + Send + Sync>> {
1435+
let context_type_id = TypeId::of::<T>();
1436+
self.context_mut().app_context.remove(&context_type_id)
14191437
}
14201438

14211439
/// Retrieves a reference to the application context associated with the Connection.
14221440
///
1423-
/// If an application context hasn't already been set on the Connection, or if the set
1424-
/// application context isn't of type T, None will be returned.
1441+
/// Returns None if the provided type T does not match the type of any application context set on the Connection.
14251442
///
14261443
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a
14271444
/// mutable reference to the context, use [`Self::application_context_mut()`].
14281445
///
14291446
/// Corresponds to [s2n_connection_get_ctx].
14301447
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
1431-
match self.context().app_context.as_ref() {
1432-
None => None,
1433-
// The Any trait keeps track of the application context's type. downcast_ref() returns
1434-
// Some only if the correct type is provided:
1435-
// https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
1436-
Some(app_context) => app_context.downcast_ref::<T>(),
1437-
}
1448+
let context_type_id = TypeId::of::<T>();
1449+
// The Any trait keeps track of the application context's type. downcast_ref() returns
1450+
// Some only if the correct type is provided:
1451+
// https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
1452+
self.context()
1453+
.app_context
1454+
.get(&context_type_id)
1455+
.and_then(|app_context| app_context.downcast_ref::<T>())
14381456
}
14391457

14401458
/// Retrieves a mutable reference to the application context associated with the Connection.
14411459
///
1442-
/// If an application context hasn't already been set on the Connection, or if the set
1443-
/// application context isn't of type T, None will be returned.
1460+
/// Returns None if the provided type T does not match the type of any application context set on the Connection.
14441461
///
14451462
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an
14461463
/// immutable reference to the context, use [`Self::application_context()`].
14471464
///
14481465
/// Corresponds to [s2n_connection_get_ctx].
14491466
pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
1450-
match self.context_mut().app_context.as_mut() {
1451-
None => None,
1452-
Some(app_context) => app_context.downcast_mut::<T>(),
1453-
}
1467+
let context_type_id = TypeId::of::<T>();
1468+
self.context_mut()
1469+
.app_context
1470+
.get_mut(&context_type_id)
1471+
.and_then(|app_context| app_context.downcast_mut::<T>())
14541472
}
14551473

14561474
#[cfg(feature = "unstable-cert_authorities")]
@@ -1475,7 +1493,7 @@ struct Context {
14751493
async_callback: Option<AsyncCallback>,
14761494
verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
14771495
connection_initialized: bool,
1478-
app_context: Option<Box<dyn Any + Send + Sync>>,
1496+
app_context: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
14791497
#[cfg(feature = "unstable-renegotiate")]
14801498
pub(crate) renegotiate_state: RenegotiateState,
14811499
#[cfg(feature = "unstable-cert_authorities")]
@@ -1490,7 +1508,7 @@ impl Context {
14901508
async_callback: None,
14911509
verify_host_callback: None,
14921510
connection_initialized: false,
1493-
app_context: None,
1511+
app_context: HashMap::new(),
14941512
#[cfg(feature = "unstable-renegotiate")]
14951513
renegotiate_state: RenegotiateState::default(),
14961514
#[cfg(feature = "unstable-cert_authorities")]
@@ -1602,6 +1620,7 @@ impl Drop for Connection {
16021620
mod tests {
16031621
use super::*;
16041622
use crate::testing::{build_config, SniTestCerts, TestPair};
1623+
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
16051624

16061625
// ensure the connection context is send
16071626
#[test]
@@ -1622,10 +1641,11 @@ mod tests {
16221641
fn test_app_context_set_and_retrieve() {
16231642
let mut connection = Connection::new_server();
16241643

1644+
let test_value: u32 = 1142;
1645+
16251646
// Before a context is set, None is returned.
16261647
assert!(connection.application_context::<u32>().is_none());
16271648

1628-
let test_value: u32 = 1142;
16291649
connection.set_application_context(test_value);
16301650

16311651
// After a context is set, the application data is returned.
@@ -1669,6 +1689,39 @@ mod tests {
16691689
assert_eq!(*connection.application_context::<i16>().unwrap(), -20);
16701690
}
16711691

1692+
/// Test that multiple application contexts can be set in a connection
1693+
#[test]
1694+
fn test_multiple_app_contexts() {
1695+
let mut connection = Connection::new_server();
1696+
1697+
let first_test_value: u16 = 1142;
1698+
connection.set_application_context(first_test_value);
1699+
1700+
assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);
1701+
1702+
// Insert the second application context to the connection
1703+
let second_test_value: SocketAddr =
1704+
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
1705+
connection.set_application_context(second_test_value);
1706+
1707+
assert_eq!(
1708+
*connection.application_context::<SocketAddr>().unwrap(),
1709+
second_test_value
1710+
);
1711+
1712+
// Remove the second application context
1713+
assert_eq!(
1714+
second_test_value,
1715+
*connection
1716+
.remove_application_context::<SocketAddr>()
1717+
.unwrap()
1718+
.downcast::<SocketAddr>()
1719+
.unwrap()
1720+
);
1721+
1722+
assert!(connection.application_context::<SocketAddr>().is_none());
1723+
}
1724+
16721725
/// Test that a context of another type can't be retrieved.
16731726
#[test]
16741727
fn test_app_context_invalid_type() {

bindings/rust/standard/integration/tests/memory.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,12 +274,12 @@ mod memory_test {
274274
/// lifecycle. The static memory row is an absolute measurement, not a diff.
275275
fn assert_expected(&self) {
276276
const EXPECTED_MEMORY: &[(Lifecycle, usize)] = &[
277-
(Lifecycle::ConnectionInit, 61_482),
278-
(Lifecycle::AfterClientHello, 88_302),
279-
(Lifecycle::AfterServerHello, 116_669),
280-
(Lifecycle::AfterClientFinished, 107_976),
281-
(Lifecycle::HandshakeComplete, 90_563),
282-
(Lifecycle::ApplicationData, 90_563),
277+
(Lifecycle::ConnectionInit, 61_578),
278+
(Lifecycle::AfterClientHello, 88_406),
279+
(Lifecycle::AfterServerHello, 116_773),
280+
(Lifecycle::AfterClientFinished, 108_080),
281+
(Lifecycle::HandshakeComplete, 90_667),
282+
(Lifecycle::ApplicationData, 90_667),
283283
];
284284
let actual_memory: Vec<(Lifecycle, usize)> = Lifecycle::all_stages()
285285
.into_iter()

0 commit comments

Comments
 (0)