From a291c27e7ec46345ddde319d6abbc01534940174 Mon Sep 17 00:00:00 2001 From: Denis Cornehl Date: Tue, 11 Nov 2025 21:10:15 +0100 Subject: [PATCH] correctly shutdown encoder when using async-compression --- src/storage/compression.rs | 32 +++++++++++++------- src/storage/mod.rs | 62 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 14 deletions(-) diff --git a/src/storage/compression.rs b/src/storage/compression.rs index 4c635fd06..87e267058 100644 --- a/src/storage/compression.rs +++ b/src/storage/compression.rs @@ -90,28 +90,38 @@ pub fn compress(content: impl Read, algorithm: CompressionAlgorithm) -> Result( - output_sink: impl AsyncWrite + Unpin + Send + 'a, +/// async compression, reads from an AsyncRead, writes to an AsyncWrite. +pub async fn compress_async<'a, R, W>( + mut reader: R, + writer: W, algorithm: CompressionAlgorithm, -) -> Box { +) -> io::Result<()> +where + R: AsyncRead + Unpin + Send + 'a, + W: AsyncWrite + Unpin + Send + 'a, +{ use async_compression::tokio::write; - use tokio::io; + use tokio::io::{self, AsyncWriteExt as _}; match algorithm { CompressionAlgorithm::Zstd => { - Box::new(io::BufWriter::new(write::ZstdEncoder::new(output_sink))) + let mut enc = write::ZstdEncoder::new(writer); + io::copy(&mut reader, &mut enc).await?; + enc.shutdown().await?; } CompressionAlgorithm::Bzip2 => { - Box::new(io::BufWriter::new(write::BzEncoder::new(output_sink))) + let mut enc = write::BzEncoder::new(writer); + io::copy(&mut reader, &mut enc).await?; + enc.shutdown().await?; } CompressionAlgorithm::Gzip => { - Box::new(io::BufWriter::new(write::GzipEncoder::new(output_sink))) + let mut enc = write::GzipEncoder::new(writer); + io::copy(&mut reader, &mut enc).await?; + enc.shutdown().await?; } } + + Ok(()) } /// Wrap an AsyncRead for decompression. diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 5e4178825..6cf4bba09 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -5,7 +5,7 @@ mod s3; pub use self::compression::{CompressionAlgorithm, CompressionAlgorithms, compress, decompress}; use self::{ - compression::{wrap_reader_for_decompression, wrap_writer_for_compression}, + compression::{compress_async, wrap_reader_for_decompression}, database::DatabaseBackend, s3::S3Backend, }; @@ -543,9 +543,10 @@ impl AsyncStorage { .await?; let mut buf: Vec = Vec::new(); - tokio::io::copy( + compress_async( &mut tokio::io::BufReader::new(tokio::fs::File::open(&local_index_path).await?), - &mut wrap_writer_for_compression(&mut buf, alg), + &mut buf, + alg, ) .await?; buf @@ -1014,6 +1015,61 @@ mod test { use std::env; use test_case::test_case; + #[tokio::test] + #[test_case(CompressionAlgorithm::Zstd)] + #[test_case(CompressionAlgorithm::Bzip2)] + #[test_case(CompressionAlgorithm::Gzip)] + async fn test_async_compression(alg: CompressionAlgorithm) -> Result<()> { + const CONTENT: &[u8] = b"Hello, world! Hello, world! Hello, world! Hello, world!"; + + let compressed_index_content = { + let mut buf: Vec = Vec::new(); + compress_async(&mut io::Cursor::new(CONTENT.to_vec()), &mut buf, alg).await?; + buf + }; + + { + // try low-level async decompression + let mut decompressed_buf: Vec = Vec::new(); + let mut reader = wrap_reader_for_decompression( + io::Cursor::new(compressed_index_content.clone()), + alg, + ); + + tokio::io::copy(&mut reader, &mut io::Cursor::new(&mut decompressed_buf)).await?; + + assert_eq!(decompressed_buf, CONTENT); + } + + { + // try sync decompression + let decompressed_buf: Vec = decompress( + io::Cursor::new(compressed_index_content.clone()), + alg, + usize::MAX, + )?; + + assert_eq!(decompressed_buf, CONTENT); + } + + // try decompress via storage API + let stream = StreamingBlob { + path: "some_path.db".into(), + mime: mime::APPLICATION_OCTET_STREAM, + date_updated: Utc::now(), + compression: Some(alg), + content_length: compressed_index_content.len(), + content: Box::new(io::Cursor::new(compressed_index_content)), + }; + + let blob = stream.materialize(usize::MAX).await?; + + assert_eq!(blob.compression, None); + assert_eq!(blob.content, CONTENT); + + Ok(()) + } + #[test_case("latest", RustdocJsonFormatVersion::Latest)] #[test_case("42", RustdocJsonFormatVersion::Version(42))] fn test_json_format_version(input: &str, expected: RustdocJsonFormatVersion) {