Skip to content

Commit 822caf4

Browse files
sanityclaudegithub-actions[bot]iduartgomez
authored
fix: implement token expiration mechanism for attested contracts (#1976)
Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: nacho.d.g <iduartgomez@users.noreply.github.com> Co-authored-by: nacho.d.g <iduartgomez@gmail.com>
1 parent 624af09 commit 822caf4

File tree

13 files changed

+517
-56
lines changed

13 files changed

+517
-56
lines changed

apps/freenet-ping/app/tests/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ pub async fn base_node_test_config_with_rng(
112112
ws_api: WebsocketApiArgs {
113113
address: Some(Ipv4Addr::LOCALHOST.into()),
114114
ws_api_port: Some(ws_api_port),
115+
token_ttl_seconds: None,
116+
token_cleanup_interval_seconds: None,
115117
},
116118
network_api: NetworkArgs {
117119
public_address: Some(Ipv4Addr::LOCALHOST.into()),

crates/core/src/client_events/websocket.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use std::{
22
collections::{HashMap, VecDeque},
3-
sync::{Arc, OnceLock, RwLock},
3+
sync::{Arc, OnceLock},
44
time::Duration,
55
};
66

7+
use dashmap::DashMap;
8+
79
use axum::{
810
extract::{
911
ws::{Message, WebSocket},
@@ -43,7 +45,7 @@ impl std::ops::Deref for WebSocketRequest {
4345
}
4446
}
4547

46-
pub(crate) struct WebSocketProxy {
48+
pub struct WebSocketProxy {
4749
proxy_server_request: mpsc::Receiver<ClientConnection>,
4850
response_channels: HashMap<ClientId, mpsc::UnboundedSender<HostCallbackResult>>,
4951
}
@@ -53,10 +55,7 @@ const PARALLELISM: usize = 10; // TODO: get this from config, or whatever optima
5355
impl WebSocketProxy {
5456
pub fn create_router(server_routing: Router) -> (Self, Router) {
5557
// Create a default empty attested contracts map
56-
let attested_contracts = Arc::new(RwLock::new(HashMap::<
57-
AuthToken,
58-
(ContractInstanceId, ClientId),
59-
>::new()));
58+
let attested_contracts = Arc::new(DashMap::new());
6059
Self::create_router_with_attested_contracts(server_routing, attested_contracts)
6160
}
6261

@@ -290,27 +289,27 @@ async fn websocket_commands(
290289
Extension(attested_contracts): Extension<AttestedContractMap>,
291290
) -> Response {
292291
let on_upgrade = move |ws: WebSocket| async move {
293-
// Get the data we need and immediately drop the lock
292+
// Get the data we need from the DashMap
294293
let auth_and_instance = if let Some(token) = auth_token.as_ref() {
295-
let attested_contracts_read = attested_contracts.read().unwrap();
296-
297294
// Only collect and log map contents when trace is enabled
298295
if tracing::enabled!(tracing::Level::TRACE) {
299-
let map_contents: Vec<_> = attested_contracts_read.keys().cloned().collect();
296+
let map_contents: Vec<_> =
297+
attested_contracts.iter().map(|e| e.key().clone()).collect();
300298
tracing::trace!(?token, "attested_contracts map keys: {:?}", map_contents);
301299
}
302300

303-
if let Some((cid, _)) = attested_contracts_read.get(token) {
304-
tracing::trace!(?token, ?cid, "Found token in attested_contracts map");
305-
Some((token.clone(), *cid))
301+
if let Some(entry) = attested_contracts.get(token) {
302+
let attested = entry.value();
303+
tracing::trace!(?token, contract_id = ?attested.contract_id, "Found token in attested_contracts map");
304+
Some((token.clone(), attested.contract_id))
306305
} else {
307306
tracing::warn!(?token, "Auth token not found in attested_contracts map");
308307
None
309308
}
310309
} else {
311310
tracing::trace!("No auth token provided in WebSocket request");
312311
None
313-
}; // RwLockReadGuard is dropped here
312+
};
314313

315314
// Only evaluate auth_and_instance for trace when trace is enabled
316315
if tracing::enabled!(tracing::Level::TRACE) {

crates/core/src/config/mod.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ impl Default for ConfigArgs {
101101
ws_api: WebsocketApiArgs {
102102
address: Some(default_listening_address()),
103103
ws_api_port: Some(default_http_gateway_port()),
104+
token_ttl_seconds: None,
105+
token_cleanup_interval_seconds: None,
104106
},
105107
secrets: Default::default(),
106108
log_level: Some(tracing::log::LevelFilter::Info),
@@ -230,6 +232,12 @@ impl ConfigArgs {
230232
self.mode.get_or_insert(cfg.mode);
231233
self.ws_api.address.get_or_insert(cfg.ws_api.address);
232234
self.ws_api.ws_api_port.get_or_insert(cfg.ws_api.port);
235+
self.ws_api
236+
.token_ttl_seconds
237+
.get_or_insert(cfg.ws_api.token_ttl_seconds);
238+
self.ws_api
239+
.token_cleanup_interval_seconds
240+
.get_or_insert(cfg.ws_api.token_cleanup_interval_seconds);
233241
self.log_level.get_or_insert(cfg.log_level);
234242
self.config_paths.merge(cfg.config_paths.as_ref().clone());
235243
}
@@ -361,6 +369,14 @@ impl ConfigArgs {
361369
.ws_api
362370
.ws_api_port
363371
.unwrap_or(default_http_gateway_port()),
372+
token_ttl_seconds: self
373+
.ws_api
374+
.token_ttl_seconds
375+
.unwrap_or(default_token_ttl_seconds()),
376+
token_cleanup_interval_seconds: self
377+
.ws_api
378+
.token_cleanup_interval_seconds
379+
.unwrap_or(default_token_cleanup_interval_seconds()),
364380
},
365381
secrets,
366382
log_level: self.log_level.unwrap_or(tracing::log::LevelFilter::Info),
@@ -616,6 +632,19 @@ pub struct WebsocketApiArgs {
616632
#[arg(long, env = "WS_API_PORT")]
617633
#[serde(rename = "ws-api-port", skip_serializing_if = "Option::is_none")]
618634
pub ws_api_port: Option<u16>,
635+
636+
/// Token time-to-live in seconds (default is 86400 = 24 hours)
637+
#[arg(long, env = "TOKEN_TTL_SECONDS")]
638+
#[serde(rename = "token-ttl-seconds", skip_serializing_if = "Option::is_none")]
639+
pub token_ttl_seconds: Option<u64>,
640+
641+
/// Token cleanup interval in seconds (default is 300 = 5 minutes)
642+
#[arg(long, env = "TOKEN_CLEANUP_INTERVAL_SECONDS")]
643+
#[serde(
644+
rename = "token-cleanup-interval-seconds",
645+
skip_serializing_if = "Option::is_none"
646+
)]
647+
pub token_cleanup_interval_seconds: Option<u64>,
619648
}
620649

621650
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
@@ -627,13 +656,36 @@ pub struct WebsocketApiConfig {
627656
/// Port to expose api on
628657
#[serde(default = "default_http_gateway_port", rename = "ws-api-port")]
629658
pub port: u16,
659+
660+
/// Token time-to-live in seconds
661+
#[serde(default = "default_token_ttl_seconds", rename = "token-ttl-seconds")]
662+
pub token_ttl_seconds: u64,
663+
664+
/// Token cleanup interval in seconds
665+
#[serde(
666+
default = "default_token_cleanup_interval_seconds",
667+
rename = "token-cleanup-interval-seconds"
668+
)]
669+
pub token_cleanup_interval_seconds: u64,
670+
}
671+
672+
#[inline]
673+
const fn default_token_ttl_seconds() -> u64 {
674+
86400 // 24 hours
675+
}
676+
677+
#[inline]
678+
const fn default_token_cleanup_interval_seconds() -> u64 {
679+
300 // 5 minutes
630680
}
631681

632682
impl From<SocketAddr> for WebsocketApiConfig {
633683
fn from(addr: SocketAddr) -> Self {
634684
Self {
635685
address: addr.ip(),
636686
port: addr.port(),
687+
token_ttl_seconds: default_token_ttl_seconds(),
688+
token_cleanup_interval_seconds: default_token_cleanup_interval_seconds(),
637689
}
638690
}
639691
}
@@ -644,6 +696,8 @@ impl Default for WebsocketApiConfig {
644696
Self {
645697
address: default_listening_address(),
646698
port: default_http_gateway_port(),
699+
token_ttl_seconds: default_token_ttl_seconds(),
700+
token_cleanup_interval_seconds: default_token_cleanup_interval_seconds(),
647701
}
648702
}
649703
}

crates/core/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub mod dev_tool {
6060
use super::*;
6161
pub use crate::config::Config;
6262
pub use client_events::{
63-
test::MemoryEventsGen, test::NetworkEventGenerator, ClientEventsProxy, ClientId,
63+
test::MemoryEventsGen, test::NetworkEventGenerator, AuthToken, ClientEventsProxy, ClientId,
6464
OpenRequest,
6565
};
6666
pub use contract::{storages::Storage, Executor, OperationMode};

crates/core/src/node/mod.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,9 +1334,8 @@ pub async fn run_local_node(
13341334
ClientRequest::DelegateOp(op) => {
13351335
let attested_contract = token.and_then(|token| {
13361336
gw.attested_contracts
1337-
.read()
1338-
.ok()
1339-
.and_then(|guard| guard.get(&token).map(|(t, _)| *t))
1337+
.get(&token)
1338+
.map(|entry| entry.value().contract_id)
13401339
});
13411340
let op_name = match op {
13421341
DelegateRequest::RegisterDelegate { .. } => "RegisterDelegate",
@@ -1356,17 +1355,6 @@ pub async fn run_local_node(
13561355
if let Some(cause) = cause {
13571356
tracing::info!("disconnecting cause: {cause}");
13581357
}
1359-
// FIXME: We're not removing tokens on disconnect to allow WebSocket connections
1360-
// to use them for authentication. We should implement a proper token expiration
1361-
// mechanism instead of keeping them forever or removing them immediately.
1362-
// if let Ok(mut guard) = gw.attested_contracts.write() {
1363-
// if let Some(rm_token) = guard
1364-
// .iter()
1365-
// .find_map(|(k, (_, eid))| (eid == &id).then(|| k.clone()))
1366-
// {
1367-
// guard.remove(&rm_token);
1368-
// }
1369-
// }
13701358
continue;
13711359
}
13721360
_ => Err(ExecutorError::other(anyhow::anyhow!("not supported"))),

crates/core/src/server/http_gateway.rs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::collections::HashMap;
22
use std::net::{IpAddr, SocketAddr};
3-
use std::sync::{Arc, RwLock};
3+
use std::sync::Arc;
4+
use std::time::Instant;
5+
6+
use dashmap::DashMap;
47

58
use axum::extract::Path;
69
use axum::response::IntoResponse;
@@ -31,19 +34,42 @@ impl std::ops::Deref for HttpGatewayRequest {
3134
}
3235
}
3336

34-
pub type AttestedContractMap = Arc<RwLock<HashMap<AuthToken, (ContractInstanceId, ClientId)>>>;
37+
/// Represents an attested contract entry with metadata for token expiration.
38+
#[derive(Clone, Debug)]
39+
pub struct AttestedContract {
40+
/// The contract instance ID
41+
pub contract_id: ContractInstanceId,
42+
/// The client ID associated with this token
43+
pub client_id: ClientId,
44+
/// Timestamp of when the token was last accessed (for expiration tracking)
45+
pub last_accessed: Instant,
46+
}
47+
48+
impl AttestedContract {
49+
/// Create a new attested contract entry
50+
pub fn new(contract_id: ContractInstanceId, client_id: ClientId) -> Self {
51+
Self {
52+
contract_id,
53+
client_id,
54+
last_accessed: Instant::now(),
55+
}
56+
}
57+
}
58+
59+
/// Maps authentication tokens to attested contract metadata.
60+
pub type AttestedContractMap = Arc<DashMap<AuthToken, AttestedContract>>;
3561

3662
/// A gateway to access and interact with contracts through an HTTP interface.
37-
pub(crate) struct HttpGateway {
38-
pub attested_contracts: AttestedContractMap,
63+
pub struct HttpGateway {
64+
pub(crate) attested_contracts: AttestedContractMap,
3965
proxy_server_request: mpsc::Receiver<ClientConnection>,
4066
response_channels: HashMap<ClientId, mpsc::UnboundedSender<HostCallbackResult>>,
4167
}
4268

4369
impl HttpGateway {
4470
/// Returns the uninitialized axum router to compose with other routing handling or websockets.
4571
pub fn as_router(socket: &SocketAddr) -> (Self, Router) {
46-
let attested_contracts = Arc::new(RwLock::new(HashMap::new()));
72+
let attested_contracts = Arc::new(DashMap::new());
4773
Self::as_router_with_attested_contracts(socket, attested_contracts)
4874
}
4975

@@ -54,6 +80,12 @@ impl HttpGateway {
5480
) -> (Self, Router) {
5581
Self::create_router_v1_with_attested_contracts(socket, attested_contracts)
5682
}
83+
84+
/// Returns a reference to the attested contracts map (for integration testing).
85+
/// This allows tests to verify token expiration behavior.
86+
pub fn attested_contracts(&self) -> &AttestedContractMap {
87+
&self.attested_contracts
88+
}
5789
}
5890

5991
#[derive(Clone, Debug)]
@@ -81,10 +113,9 @@ impl ClientEventsProxy for HttpGateway {
81113
.send(HostCallbackResult::NewId { id: cli_id })
82114
.map_err(|_e| ErrorKind::NodeUnavailable)?;
83115
if let Some((assigned_token, contract)) = assigned_token {
116+
let attested = AttestedContract::new(contract, cli_id);
84117
self.attested_contracts
85-
.write()
86-
.map_err(|_| ErrorKind::FailedOperation)?
87-
.insert(assigned_token.clone(), (contract, cli_id));
118+
.insert(assigned_token.clone(), attested);
88119
tracing::debug!(
89120
?assigned_token,
90121
?contract,

0 commit comments

Comments
 (0)