Skip to content

Commit 22b93e4

Browse files
authored
Merge pull request #25 from n0-computer/gateway/mime-from-name
feat: use file extensions from collection names for MIME guessing
2 parents 1575736 + 9209561 commit 22b93e4

File tree

3 files changed

+73
-34
lines changed

3 files changed

+73
-34
lines changed

iroh-gateway/Cargo.lock

Lines changed: 20 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

iroh-gateway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ tokio-rustls-acme = { version = "0.2.0", features = ["axum"] }
3030
hyper-util = "0.1.2"
3131
rustls-pemfile = "1.0.2"
3232
tower-service = "0.3.2"
33+
mime_guess = "2.0.4"

iroh-gateway/src/main.rs

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ struct Inner {
9595
#[debug("MimeClassifier")]
9696
mime_classifier: MimeClassifier,
9797
/// Cache of hashes to mime types
98-
mime_cache: Mutex<LruCache<Hash, (u64, Mime)>>,
98+
mime_cache: Mutex<LruCache<(Hash, Option<String>), (u64, Mime)>>,
9999
/// Cache of hashes to collections
100100
collection_cache: Mutex<LruCache<Hash, Collection>>,
101101
}
@@ -175,25 +175,19 @@ async fn get_collection(
175175
return Ok(res.clone());
176176
}
177177
let (collection, headers) = get_collection_inner(hash, connection, true).await?;
178-
let mimes = headers
179-
.into_iter()
180-
.map(|(hash, size, header)| {
181-
let mime = gateway.mime_classifier.classify(
182-
mime_classifier::LoadContext::Browsing,
183-
mime_classifier::NoSniffFlag::Off,
184-
mime_classifier::ApacheBugFlag::On,
185-
&None,
186-
&header,
187-
);
188-
(hash, size, mime)
189-
})
190-
.collect::<Vec<_>>();
191-
{
192-
let mut cache = gateway.mime_cache.lock().unwrap();
193-
for (hash, size, mime) in mimes {
194-
cache.put(hash, (size, mime));
195-
}
178+
179+
let mut cache = gateway.mime_cache.lock().unwrap();
180+
for (name, hash) in collection.iter() {
181+
let ext = get_extension(name);
182+
let Some((hash, size, data)) = headers.iter().find(|(h, _, _)| h == hash) else {
183+
tracing::debug!("hash {hash:?} for name {name:?} not found in headers");
184+
continue;
185+
};
186+
let mime = get_mime_from_ext_and_data(ext.as_deref(), &data, &gateway.mime_classifier);
187+
let key = (*hash, ext);
188+
cache.put(key, (*size, mime));
196189
}
190+
drop(cache);
197191

198192
gateway
199193
.collection_cache
@@ -203,9 +197,16 @@ async fn get_collection(
203197
Ok(collection)
204198
}
205199

200+
/// Extract the file extension from a collection entry name.
///
/// The name is interpreted as a path, so only the final path component is
/// considered. Returns `None` when that component has no extension (for
/// example `README`, `.gitignore`, or an empty name).
fn get_extension(name: &str) -> Option<String> {
    std::path::Path::new(name)
        .extension()
        // `name` is a `&str`, so the extension is always valid UTF-8 and
        // `to_str` cannot fail here; no lossy conversion is required.
        .and_then(|ext| ext.to_str())
        .map(str::to_owned)
}
205+
206206
/// Get the mime type for a hash from the remote node.
207207
async fn get_mime_type_inner(
208208
hash: &Hash,
209+
ext: Option<&str>,
209210
connection: &quinn::Connection,
210211
mime_classifier: &MimeClassifier,
211212
) -> anyhow::Result<(u64, Mime)> {
@@ -223,31 +224,46 @@ async fn get_mime_type_inner(
223224
anyhow::bail!("unexpected response");
224225
};
225226
let _stats = at_closing.next().await?;
227+
let mime = get_mime_from_ext_and_data(ext, &data, mime_classifier);
228+
Ok((size, mime))
229+
}
230+
231+
fn get_mime_from_ext_and_data(
232+
ext: Option<&str>,
233+
data: &[u8],
234+
mime_classifier: &MimeClassifier,
235+
) -> Mime {
226236
let context = mime_classifier::LoadContext::Browsing;
227-
let no_sniff_flag = mime_classifier::NoSniffFlag::Off;
237+
let no_sniff_flag = mime_classifier::NoSniffFlag::On;
228238
let apache_bug_flag = mime_classifier::ApacheBugFlag::On;
229-
let supplied_type = None;
230-
let mime = mime_classifier.classify(
239+
let supplied_type = match ext {
240+
None => None,
241+
Some(ext) => mime_guess::from_ext(ext).first(),
242+
};
243+
mime_classifier.classify(
231244
context,
232245
no_sniff_flag,
233246
apache_bug_flag,
234247
&supplied_type,
235-
&data,
236-
);
237-
Ok((size, mime))
248+
data,
249+
)
238250
}
239251

240252
/// Get the mime type for a hash, either from the cache or by requesting it from the node.
241253
async fn get_mime_type(
242254
gateway: &Gateway,
243255
hash: &Hash,
256+
name: Option<&str>,
244257
connection: &quinn::Connection,
245258
) -> anyhow::Result<(u64, Mime)> {
246-
if let Some(sm) = gateway.mime_cache.lock().unwrap().get(hash) {
259+
let ext = name.map(|n| get_extension(n)).flatten();
260+
let key = (*hash, ext.clone());
261+
if let Some(sm) = gateway.mime_cache.lock().unwrap().get(&key) {
247262
return Ok(sm.clone());
248263
}
249-
let sm = get_mime_type_inner(hash, connection, &gateway.mime_classifier).await?;
250-
gateway.mime_cache.lock().unwrap().put(*hash, sm.clone());
264+
let sm =
265+
get_mime_type_inner(hash, ext.as_deref(), connection, &gateway.mime_classifier).await?;
266+
gateway.mime_cache.lock().unwrap().put(key, sm.clone());
251267
Ok(sm)
252268
}
253269

@@ -259,7 +275,7 @@ async fn handle_local_blob_request(
259275
) -> std::result::Result<Response<Body>, AppError> {
260276
let connection = gateway.get_default_connection().await?;
261277
let byte_range = parse_byte_range(req).await?;
262-
let res = forward_range(&gateway, connection, &blake3_hash, byte_range).await?;
278+
let res = forward_range(&gateway, connection, &blake3_hash, None, byte_range).await?;
263279
Ok(res)
264280
}
265281

@@ -299,7 +315,7 @@ async fn handle_ticket_index(
299315
let hash = ticket.hash();
300316
let prefix = format!("/ticket/{}", ticket);
301317
let res = match ticket.format() {
302-
BlobFormat::Raw => forward_range(&gateway, connection, &hash, byte_range)
318+
BlobFormat::Raw => forward_range(&gateway, connection, &hash, None, byte_range)
303319
.await?
304320
.into_response(),
305321
BlobFormat::HashSeq => collection_index(&gateway, connection, &hash, &prefix)
@@ -345,7 +361,8 @@ async fn collection_index(
345361
for (name, child_hash) in collection.iter() {
346362
let url = format!("{}/{}", link_prefix, name);
347363
let url = encode_relative_url(&url)?;
348-
let smo = gateway.mime_cache.lock().unwrap().get(child_hash).cloned();
364+
let key = (*child_hash, get_extension(name));
365+
let smo = gateway.mime_cache.lock().unwrap().get(&key).cloned();
349366
res.push_str(&format!("<a href=\"{}\">{}</a>", url, name,));
350367
if let Some((size, mime)) = smo {
351368
res.push_str(&format!(" ({}, {})", mime, indicatif::HumanBytes(size)));
@@ -373,7 +390,7 @@ async fn forward_collection_range(
373390
let collection = get_collection(gateway, hash, &connection).await?;
374391
for (name, hash) in collection.iter() {
375392
if name == suffix {
376-
let res = forward_range(gateway, connection, hash, range).await?;
393+
let res = forward_range(gateway, connection, hash, Some(suffix), range).await?;
377394
return Ok(res.into_response());
378395
} else {
379396
tracing::trace!("'{}' != '{}'", name, suffix);
@@ -400,16 +417,17 @@ async fn forward_range(
400417
gateway: &Gateway,
401418
connection: quinn::Connection,
402419
hash: &Hash,
420+
name: Option<&str>,
403421
(start, end): (Option<u64>, Option<u64>),
404422
) -> anyhow::Result<Response<Body>> {
405423
// we need both byte ranges and chunk ranges.
406424
// chunk ranges to request data, and byte ranges to return the data.
407-
tracing::debug!("forward_range {:?} {:?}", start, end);
425+
tracing::debug!("forward_range {:?} {:?} (name {name:?})", start, end);
408426

409427
let byte_ranges = to_byte_range(start, end);
410428
let chunk_ranges = to_chunk_range(start, end);
411429
tracing::debug!("got connection");
412-
let (_size, mime) = get_mime_type(gateway, hash, &connection).await?;
430+
let (_size, mime) = get_mime_type(gateway, hash, name, &connection).await?;
413431
tracing::debug!("mime: {}", mime);
414432
let chunk_ranges = RangeSpecSeq::from_ranges(vec![chunk_ranges]);
415433
let request = iroh::bytes::protocol::GetRequest::new(*hash, chunk_ranges.clone());

0 commit comments

Comments
 (0)