@@ -95,7 +95,7 @@ struct Inner {
9595 #[ debug( "MimeClassifier" ) ]
9696 mime_classifier : MimeClassifier ,
9797 /// Cache of hashes to mime types
98- mime_cache : Mutex < LruCache < Hash , ( u64 , Mime ) > > ,
98+ mime_cache : Mutex < LruCache < ( Hash , Option < String > ) , ( u64 , Mime ) > > ,
9999 /// Cache of hashes to collections
100100 collection_cache : Mutex < LruCache < Hash , Collection > > ,
101101}
@@ -175,25 +175,19 @@ async fn get_collection(
175175 return Ok ( res. clone ( ) ) ;
176176 }
177177 let ( collection, headers) = get_collection_inner ( hash, connection, true ) . await ?;
178- let mimes = headers
179- . into_iter ( )
180- . map ( |( hash, size, header) | {
181- let mime = gateway. mime_classifier . classify (
182- mime_classifier:: LoadContext :: Browsing ,
183- mime_classifier:: NoSniffFlag :: Off ,
184- mime_classifier:: ApacheBugFlag :: On ,
185- & None ,
186- & header,
187- ) ;
188- ( hash, size, mime)
189- } )
190- . collect :: < Vec < _ > > ( ) ;
191- {
192- let mut cache = gateway. mime_cache . lock ( ) . unwrap ( ) ;
193- for ( hash, size, mime) in mimes {
194- cache. put ( hash, ( size, mime) ) ;
195- }
178+
179+ let mut cache = gateway. mime_cache . lock ( ) . unwrap ( ) ;
180+ for ( name, hash) in collection. iter ( ) {
181+ let ext = get_extension ( name) ;
182+ let Some ( ( hash, size, data) ) = headers. iter ( ) . find ( |( h, _, _) | h == hash) else {
183+ tracing:: debug!( "hash {hash:?} for name {name:?} not found in headers" ) ;
184+ continue ;
185+ } ;
186+ let mime = get_mime_from_ext_and_data ( ext. as_deref ( ) , & data, & gateway. mime_classifier ) ;
187+ let key = ( * hash, ext) ;
188+ cache. put ( key, ( * size, mime) ) ;
196189 }
190+ drop ( cache) ;
197191
198192 gateway
199193 . collection_cache
@@ -203,9 +197,16 @@ async fn get_collection(
203197 Ok ( collection)
204198}
205199
200+ fn get_extension ( name : & str ) -> Option < String > {
201+ std:: path:: Path :: new ( name)
202+ . extension ( )
203+ . map ( |s| s. to_string_lossy ( ) . to_string ( ) )
204+ }
205+
206206/// Get the mime type for a hash from the remote node.
207207async fn get_mime_type_inner (
208208 hash : & Hash ,
209+ ext : Option < & str > ,
209210 connection : & quinn:: Connection ,
210211 mime_classifier : & MimeClassifier ,
211212) -> anyhow:: Result < ( u64 , Mime ) > {
@@ -223,31 +224,46 @@ async fn get_mime_type_inner(
223224 anyhow:: bail!( "unexpected response" ) ;
224225 } ;
225226 let _stats = at_closing. next ( ) . await ?;
227+ let mime = get_mime_from_ext_and_data ( ext, & data, mime_classifier) ;
228+ Ok ( ( size, mime) )
229+ }
230+
231+ fn get_mime_from_ext_and_data (
232+ ext : Option < & str > ,
233+ data : & [ u8 ] ,
234+ mime_classifier : & MimeClassifier ,
235+ ) -> Mime {
226236 let context = mime_classifier:: LoadContext :: Browsing ;
227- let no_sniff_flag = mime_classifier:: NoSniffFlag :: Off ;
237+ let no_sniff_flag = mime_classifier:: NoSniffFlag :: On ;
228238 let apache_bug_flag = mime_classifier:: ApacheBugFlag :: On ;
229- let supplied_type = None ;
230- let mime = mime_classifier. classify (
239+ let supplied_type = match ext {
240+ None => None ,
241+ Some ( ext) => mime_guess:: from_ext ( ext) . first ( ) ,
242+ } ;
243+ mime_classifier. classify (
231244 context,
232245 no_sniff_flag,
233246 apache_bug_flag,
234247 & supplied_type,
235- & data,
236- ) ;
237- Ok ( ( size, mime) )
248+ data,
249+ )
238250}
239251
240252/// Get the mime type for a hash, either from the cache or by requesting it from the node.
241253async fn get_mime_type (
242254 gateway : & Gateway ,
243255 hash : & Hash ,
256+ name : Option < & str > ,
244257 connection : & quinn:: Connection ,
245258) -> anyhow:: Result < ( u64 , Mime ) > {
246- if let Some ( sm) = gateway. mime_cache . lock ( ) . unwrap ( ) . get ( hash) {
259+ let ext = name. map ( |n| get_extension ( n) ) . flatten ( ) ;
260+ let key = ( * hash, ext. clone ( ) ) ;
261+ if let Some ( sm) = gateway. mime_cache . lock ( ) . unwrap ( ) . get ( & key) {
247262 return Ok ( sm. clone ( ) ) ;
248263 }
249- let sm = get_mime_type_inner ( hash, connection, & gateway. mime_classifier ) . await ?;
250- gateway. mime_cache . lock ( ) . unwrap ( ) . put ( * hash, sm. clone ( ) ) ;
264+ let sm =
265+ get_mime_type_inner ( hash, ext. as_deref ( ) , connection, & gateway. mime_classifier ) . await ?;
266+ gateway. mime_cache . lock ( ) . unwrap ( ) . put ( key, sm. clone ( ) ) ;
251267 Ok ( sm)
252268}
253269
@@ -259,7 +275,7 @@ async fn handle_local_blob_request(
259275) -> std:: result:: Result < Response < Body > , AppError > {
260276 let connection = gateway. get_default_connection ( ) . await ?;
261277 let byte_range = parse_byte_range ( req) . await ?;
262- let res = forward_range ( & gateway, connection, & blake3_hash, byte_range) . await ?;
278+ let res = forward_range ( & gateway, connection, & blake3_hash, None , byte_range) . await ?;
263279 Ok ( res)
264280}
265281
@@ -299,7 +315,7 @@ async fn handle_ticket_index(
299315 let hash = ticket. hash ( ) ;
300316 let prefix = format ! ( "/ticket/{}" , ticket) ;
301317 let res = match ticket. format ( ) {
302- BlobFormat :: Raw => forward_range ( & gateway, connection, & hash, byte_range)
318+ BlobFormat :: Raw => forward_range ( & gateway, connection, & hash, None , byte_range)
303319 . await ?
304320 . into_response ( ) ,
305321 BlobFormat :: HashSeq => collection_index ( & gateway, connection, & hash, & prefix)
@@ -345,7 +361,8 @@ async fn collection_index(
345361 for ( name, child_hash) in collection. iter ( ) {
346362 let url = format ! ( "{}/{}" , link_prefix, name) ;
347363 let url = encode_relative_url ( & url) ?;
348- let smo = gateway. mime_cache . lock ( ) . unwrap ( ) . get ( child_hash) . cloned ( ) ;
364+ let key = ( * child_hash, get_extension ( name) ) ;
365+ let smo = gateway. mime_cache . lock ( ) . unwrap ( ) . get ( & key) . cloned ( ) ;
349366 res. push_str ( & format ! ( "<a href=\" {}\" >{}</a>" , url, name, ) ) ;
350367 if let Some ( ( size, mime) ) = smo {
351368 res. push_str ( & format ! ( " ({}, {})" , mime, indicatif:: HumanBytes ( size) ) ) ;
@@ -373,7 +390,7 @@ async fn forward_collection_range(
373390 let collection = get_collection ( gateway, hash, & connection) . await ?;
374391 for ( name, hash) in collection. iter ( ) {
375392 if name == suffix {
376- let res = forward_range ( gateway, connection, hash, range) . await ?;
393+ let res = forward_range ( gateway, connection, hash, Some ( suffix ) , range) . await ?;
377394 return Ok ( res. into_response ( ) ) ;
378395 } else {
379396 tracing:: trace!( "'{}' != '{}'" , name, suffix) ;
@@ -400,16 +417,17 @@ async fn forward_range(
400417 gateway : & Gateway ,
401418 connection : quinn:: Connection ,
402419 hash : & Hash ,
420+ name : Option < & str > ,
403421 ( start, end) : ( Option < u64 > , Option < u64 > ) ,
404422) -> anyhow:: Result < Response < Body > > {
405423 // we need both byte ranges and chunk ranges.
406424 // chunk ranges to request data, and byte ranges to return the data.
407- tracing:: debug!( "forward_range {:?} {:?}" , start, end) ;
425+ tracing:: debug!( "forward_range {:?} {:?} (name {name:?}) " , start, end) ;
408426
409427 let byte_ranges = to_byte_range ( start, end) ;
410428 let chunk_ranges = to_chunk_range ( start, end) ;
411429 tracing:: debug!( "got connection" ) ;
412- let ( _size, mime) = get_mime_type ( gateway, hash, & connection) . await ?;
430+ let ( _size, mime) = get_mime_type ( gateway, hash, name , & connection) . await ?;
413431 tracing:: debug!( "mime: {}" , mime) ;
414432 let chunk_ranges = RangeSpecSeq :: from_ranges ( vec ! [ chunk_ranges] ) ;
415433 let request = iroh:: bytes:: protocol:: GetRequest :: new ( * hash, chunk_ranges. clone ( ) ) ;
0 commit comments