1+ use dashmap:: DashMap ;
12use deno_core:: error:: AnyError ;
23use futures:: io:: AllowStdIo ;
34use once_cell:: sync:: Lazy ;
45use reqwest:: Url ;
5- use std:: collections:: HashMap ;
66use std:: hash:: Hasher ;
77use std:: sync:: Arc ;
88use std:: sync:: Mutex ;
9- use tokio:: sync:: Mutex as AsyncMutex ;
109use tokio_util:: compat:: FuturesAsyncWriteCompatExt ;
1110use tracing:: debug;
1211use tracing:: instrument;
@@ -26,8 +25,8 @@ use ort::session::Session;
2625
2726use crate :: onnx:: ensure_onnx_env_init;
2827
29- static SESSIONS : Lazy < AsyncMutex < HashMap < String , Arc < Mutex < Session > > > > > =
30- Lazy :: new ( || AsyncMutex :: new ( HashMap :: new ( ) ) ) ;
28+ static SESSIONS : Lazy < DashMap < String , Arc < Mutex < Session > > > > =
29+ Lazy :: new ( DashMap :: new) ;
3130
3231#[ derive( Debug ) ]
3332pub struct SessionWithId {
@@ -136,16 +135,14 @@ pub(crate) async fn load_session_from_bytes(
136135 faster_hex:: hex_string ( & hasher. finish ( ) . to_be_bytes ( ) )
137136 } ;
138137
139- let mut sessions = SESSIONS . lock ( ) . await ;
140-
141- if let Some ( session) = sessions. get ( & session_id) {
138+ if let Some ( session) = SESSIONS . get ( & session_id) {
142139 return Ok ( ( session_id, session. clone ( ) ) . into ( ) ) ;
143140 }
144141
145142 trace ! ( session_id, "new session" ) ;
146143 let session = create_session ( model_bytes) ?;
147144
148- sessions . insert ( session_id. clone ( ) , session. clone ( ) ) ;
145+ SESSIONS . insert ( session_id. clone ( ) , session. clone ( ) ) ;
149146
150147 Ok ( ( session_id, session) . into ( ) )
151148}
@@ -156,9 +153,7 @@ pub(crate) async fn load_session_from_url(
156153) -> Result < SessionWithId , Error > {
157154 let session_id = fxhash:: hash ( model_url. as_str ( ) ) . to_string ( ) ;
158155
159- let mut sessions = SESSIONS . lock ( ) . await ;
160-
161- if let Some ( session) = sessions. get ( & session_id) {
156+ if let Some ( session) = SESSIONS . get ( & session_id) {
162157 debug ! ( session_id, "use existing session" ) ;
163158 return Ok ( ( session_id, session. clone ( ) ) . into ( ) ) ;
164159 }
@@ -174,22 +169,23 @@ pub(crate) async fn load_session_from_url(
174169 let session = create_session ( model_bytes. as_slice ( ) ) ?;
175170
176171 debug ! ( session_id, "new session" ) ;
177- sessions . insert ( session_id. clone ( ) , session. clone ( ) ) ;
172+ SESSIONS . insert ( session_id. clone ( ) , session. clone ( ) ) ;
178173
179174 Ok ( ( session_id, session) . into ( ) )
180175}
181176
182177pub ( crate ) async fn get_session ( id : & str ) -> Option < Arc < Mutex < Session > > > {
183- SESSIONS . lock ( ) . await . get ( id ) . cloned ( )
178+ SESSIONS . get ( id ) . map ( |value| value . pair ( ) . 1 . clone ( ) )
184179}
185180
186181pub async fn cleanup ( ) -> Result < usize , AnyError > {
187182 let mut remove_counter = 0 ;
188183 {
189- let mut guard = SESSIONS . lock ( ) . await ;
184+ // let mut guard = SESSIONS.lock().await;
190185 let mut to_be_removed = vec ! [ ] ;
191186
192- for ( key, session) in & mut * guard {
187+ for v in SESSIONS . iter ( ) {
188+ let ( key, session) = v. pair ( ) ;
193189 // Since we're currently referencing the session at this point
194190 // It also will increments the counter, so we need to check: counter > 1
195191 if Arc :: strong_count ( session) > 1 {
@@ -200,7 +196,7 @@ pub async fn cleanup() -> Result<usize, AnyError> {
200196 }
201197
202198 for key in to_be_removed {
203- let old_store = guard . remove ( & key) ;
199+ let old_store = SESSIONS . remove ( & key) ;
204200 debug_assert ! ( old_store. is_some( ) ) ;
205201
206202 remove_counter += 1 ;
0 commit comments