
Commit 3b3d9a0

correctly shutdown encoder when using async-compression
1 parent c5ca36d
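Background for the fix: in async-compression, the write-side encoders only emit their final compressed frame (e.g. the gzip CRC/length footer) when the writer is shut down. The old helper handed callers a boxed AsyncWrite and left finalization to them; writing and flushing alone leaves the stream unterminated. A minimal sketch of the required encoder lifecycle, assuming the async-compression crate with its tokio and gzip features (the helper name is illustrative, not from this commit):

```rust
use async_compression::tokio::write::GzipEncoder;
use tokio::io::AsyncWriteExt as _;

// Illustrative helper, not part of the diff: gzip-compress a byte slice.
async fn encode(data: &[u8]) -> std::io::Result<Vec<u8>> {
    let mut enc = GzipEncoder::new(Vec::new());
    enc.write_all(data).await?;
    // Without shutdown() the trailing gzip footer is never written and the
    // output is truncated; flush() alone does not finish the stream.
    enc.shutdown().await?;
    Ok(enc.into_inner())
}
```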

2 files changed: +80 -14 lines

src/storage/compression.rs

Lines changed: 21 additions & 11 deletions
```diff
@@ -90,28 +90,38 @@ pub fn compress(content: impl Read, algorithm: CompressionAlgorithm) -> Result<V
     }
 }
 
-/// Wrap an AsyncWrite sink for compression using the specified algorithm.
-///
-/// Will return an AsyncWrite you can just write data to, we will compress
-/// the data, and then write the compressed data into the provided output sink.
-pub fn wrap_writer_for_compression<'a>(
-    output_sink: impl AsyncWrite + Unpin + Send + 'a,
+/// async compression, reads from an AsyncRead, writes to an AsyncWrite.
+pub async fn compress_async<'a, R, W>(
+    mut reader: R,
+    writer: W,
     algorithm: CompressionAlgorithm,
-) -> Box<dyn AsyncWrite + Unpin + 'a> {
+) -> io::Result<()>
+where
+    R: AsyncRead + Unpin + Send + 'a,
+    W: AsyncWrite + Unpin + Send + 'a,
+{
     use async_compression::tokio::write;
-    use tokio::io;
+    use tokio::io::{self, AsyncWriteExt as _};
 
     match algorithm {
         CompressionAlgorithm::Zstd => {
-            Box::new(io::BufWriter::new(write::ZstdEncoder::new(output_sink)))
+            let mut enc = write::ZstdEncoder::new(writer);
+            io::copy(&mut reader, &mut enc).await?;
+            enc.shutdown().await?;
         }
         CompressionAlgorithm::Bzip2 => {
-            Box::new(io::BufWriter::new(write::BzEncoder::new(output_sink)))
+            let mut enc = write::BzEncoder::new(writer);
+            io::copy(&mut reader, &mut enc).await?;
+            enc.shutdown().await?;
         }
         CompressionAlgorithm::Gzip => {
-            Box::new(io::BufWriter::new(write::GzipEncoder::new(output_sink)))
+            let mut enc = write::GzipEncoder::new(writer);
+            io::copy(&mut reader, &mut enc).await?;
+            enc.shutdown().await?;
         }
     }
+
+    Ok(())
 }
 
 /// Wrap an AsyncRead for decompression.
```
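With the encoder created, driven, and shut down entirely inside compress_async, callers can no longer forget the finalization step. A hypothetical in-memory caller, mirroring how the test below uses the function (the helper name is illustrative, not from this commit):

```rust
use std::io;

// Illustrative helper, not part of the diff: compress a byte slice in memory.
// std::io::Cursor implements tokio's AsyncRead, and Vec<u8> its AsyncWrite.
async fn compress_bytes(input: &[u8], alg: CompressionAlgorithm) -> io::Result<Vec<u8>> {
    let mut buf: Vec<u8> = Vec::new();
    compress_async(&mut io::Cursor::new(input.to_vec()), &mut buf, alg).await?;
    Ok(buf)
}
```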

src/storage/mod.rs

Lines changed: 59 additions & 3 deletions
```diff
@@ -5,7 +5,7 @@ mod s3;
 
 pub use self::compression::{CompressionAlgorithm, CompressionAlgorithms, compress, decompress};
 use self::{
-    compression::{wrap_reader_for_decompression, wrap_writer_for_compression},
+    compression::{compress_async, wrap_reader_for_decompression},
     database::DatabaseBackend,
     s3::S3Backend,
 };
@@ -543,9 +543,10 @@ impl AsyncStorage {
             .await?;
 
         let mut buf: Vec<u8> = Vec::new();
-        tokio::io::copy(
+        compress_async(
             &mut tokio::io::BufReader::new(tokio::fs::File::open(&local_index_path).await?),
-            &mut wrap_writer_for_compression(&mut buf, alg),
+            &mut buf,
+            alg,
         )
         .await?;
         buf
@@ -1014,6 +1015,61 @@ mod test {
     use std::env;
     use test_case::test_case;
 
+    #[tokio::test]
+    #[test_case(CompressionAlgorithm::Zstd)]
+    #[test_case(CompressionAlgorithm::Bzip2)]
+    #[test_case(CompressionAlgorithm::Gzip)]
+    async fn test_async_compression(alg: CompressionAlgorithm) -> Result<()> {
+        const CONTENT: &[u8] = b"Hello, world! Hello, world! Hello, world! Hello, world!";
+
+        let compressed_index_content = {
+            let mut buf: Vec<u8> = Vec::new();
+            compress_async(&mut io::Cursor::new(CONTENT.to_vec()), &mut buf, alg).await?;
+            buf
+        };
+
+        {
+            // try low-level async decompression
+            let mut decompressed_buf: Vec<u8> = Vec::new();
+            let mut reader = wrap_reader_for_decompression(
+                io::Cursor::new(compressed_index_content.clone()),
+                alg,
+            );
+
+            tokio::io::copy(&mut reader, &mut io::Cursor::new(&mut decompressed_buf)).await?;
+
+            assert_eq!(decompressed_buf, CONTENT);
+        }
+
+        {
+            // try sync decompression
+            let decompressed_buf: Vec<u8> = decompress(
+                io::Cursor::new(compressed_index_content.clone()),
+                alg,
+                usize::MAX,
+            )?;
+
+            assert_eq!(decompressed_buf, CONTENT);
+        }
+
+        // try decompress via storage API
+        let stream = StreamingBlob {
+            path: "some_path.db".into(),
+            mime: mime::APPLICATION_OCTET_STREAM,
+            date_updated: Utc::now(),
+            compression: Some(alg),
+            content_length: compressed_index_content.len(),
+            content: Box::new(io::Cursor::new(compressed_index_content)),
+        };
+
+        let blob = stream.materialize(usize::MAX).await?;
+
+        assert_eq!(blob.compression, None);
+        assert_eq!(blob.content, CONTENT);
+
+        Ok(())
+    }
+
     #[test_case("latest", RustdocJsonFormatVersion::Latest)]
     #[test_case("42", RustdocJsonFormatVersion::Version(42))]
     fn test_json_format_version(input: &str, expected: RustdocJsonFormatVersion) {
```
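Note the asymmetry this commit leans on: decompression keeps the reader-wrapper style (wrap_reader_for_decompression) because a decoder is driven to completion simply by reading until EOF, so there is no shutdown step to forget on the read side. A sketch of that path, mirroring the test above (the helper name is illustrative, not from this commit):

```rust
use std::io;

// Illustrative helper, not part of the diff: decompress a byte slice in memory.
async fn decompress_bytes(compressed: &[u8], alg: CompressionAlgorithm) -> io::Result<Vec<u8>> {
    let mut out: Vec<u8> = Vec::new();
    let mut reader = wrap_reader_for_decompression(io::Cursor::new(compressed.to_vec()), alg);
    tokio::io::copy(&mut reader, &mut io::Cursor::new(&mut out)).await?;
    Ok(out)
}
```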
