From 4f182f898d361a4ceca0a20954f4b42b21f356b9 Mon Sep 17 00:00:00 2001 From: nhz2 Date: Mon, 16 Sep 2024 12:53:10 -0400 Subject: [PATCH 1/6] Auto initialize in `startproc` --- src/compression.jl | 29 ++++++++++++++++------------- src/libzstd.jl | 6 +----- test/compress_endOp.jl | 3 +++ 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/compression.jl b/src/compression.jl index cabc3f9..9505a9f 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -78,16 +78,6 @@ end # Methods # ------- -function TranscodingStreams.initialize(codec::ZstdCompressor) - code = initialize!(codec.cstream, codec.level) - if iserror(code) - zstderror(codec.cstream, code) - end - reset!(codec.cstream.ibuffer) - reset!(codec.cstream.obuffer) - return -end - function TranscodingStreams.finalize(codec::ZstdCompressor) if codec.cstream.ptr != C_NULL code = free!(codec.cstream) @@ -96,12 +86,22 @@ function TranscodingStreams.finalize(codec::ZstdCompressor) end codec.cstream.ptr = C_NULL end - reset!(codec.cstream.ibuffer) - reset!(codec.cstream.obuffer) - return + nothing end function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error) + if codec.cstream.ptr == C_NULL + ptr = LibZstd.ZSTD_createCStream() + if ptr == C_NULL + throw(OutOfMemoryError()) + end + codec.cstream.ptr = ptr + i_code = initialize!(codec.cstream, codec.level) + if iserror(i_code) + error[] = ErrorException("zstd error") + return :error + end + end code = reset!(codec.cstream, 0 #=unknown source size=#) if iserror(code) error[] = ErrorException("zstd error") @@ -111,6 +111,9 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error end function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, error::Error) + if codec.cstream.ptr == C_NULL + error("startproc must be called before process") + end cstream = codec.cstream ibuffer_starting_pos = UInt(0) if codec.endOp == LibZstd.ZSTD_e_end && diff --git a/src/libzstd.jl b/src/libzstd.jl index c11b1f1..7a557d4 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -44,11 +44,7 @@ mutable struct CStream obuffer::OutBuffer function CStream() - ptr = LibZstd.ZSTD_createCStream() - if ptr == C_NULL - throw(OutOfMemoryError()) - end - return new(ptr, InBuffer(), OutBuffer()) + return new(C_NULL, InBuffer(), OutBuffer()) end end diff --git a/test/compress_endOp.jl b/test/compress_endOp.jl index 0594f1f..28fd9f2 100644 --- a/test/compress_endOp.jl +++ b/test/compress_endOp.jl @@ -4,6 +4,7 @@ using Test @testset "compress! endOp = :continue" begin data = rand(1:100, 1024*1024) cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() cstream.ibuffer.src = pointer(data) cstream.ibuffer.size = sizeof(data) cstream.ibuffer.pos = 0 @@ -24,6 +25,7 @@ end @testset "compress! endOp = :flush" begin data = rand(1:100, 1024*1024) cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() cstream.ibuffer.src = pointer(data) cstream.ibuffer.size = sizeof(data) cstream.ibuffer.pos = 0 @@ -43,6 +45,7 @@ end @testset "compress! endOp = :end" begin data = rand(1:100, 1024*1024) cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() cstream.ibuffer.src = pointer(data) cstream.ibuffer.size = sizeof(data) cstream.ibuffer.pos = 0 From 0bf32c0468b821523d5641fb0deb3036f2e30aed Mon Sep 17 00:00:00 2001 From: nhz2 Date: Mon, 16 Sep 2024 13:11:51 -0400 Subject: [PATCH 2/6] Add tests --- src/compression.jl | 2 +- src/decompression.jl | 29 ++++++++++--------- src/libzstd.jl | 6 +--- test/runtests.jl | 68 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 19 deletions(-) diff --git a/src/compression.jl b/src/compression.jl index 9505a9f..f7b9ee7 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -98,7 +98,7 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error codec.cstream.ptr = ptr i_code = initialize!(codec.cstream, codec.level) if iserror(i_code) - error[] = ErrorException("zstd error") + error[] = ErrorException("zstd initialization error") return :error end end diff --git a/src/decompression.jl b/src/decompression.jl index 765ce2c..6d680bf 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -33,16 +33,6 @@ end # Methods # ------- -function TranscodingStreams.initialize(codec::ZstdDecompressor) - code = initialize!(codec.dstream) - if iserror(code) - zstderror(codec.dstream, code) - end - reset!(codec.dstream.ibuffer) - reset!(codec.dstream.obuffer) - return -end - function TranscodingStreams.finalize(codec::ZstdDecompressor) if codec.dstream.ptr != C_NULL code = free!(codec.dstream) @@ -51,12 +41,22 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor) end codec.dstream.ptr = C_NULL end - reset!(codec.dstream.ibuffer) - reset!(codec.dstream.obuffer) - return + nothing end function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error) + if codec.dstream.ptr == C_NULL + ptr = LibZstd.ZSTD_createDStream() + if ptr == C_NULL + throw(OutOfMemoryError()) + end + codec.dstream.ptr = ptr + i_code = initialize!(codec.dstream) + if iserror(i_code) + error[] = ErrorException("zstd initialization error") + return :error + end + end code = reset!(codec.dstream) if iserror(code) error[] = ErrorException("zstd error") @@ -66,6 +66,9 @@ function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, err end function TranscodingStreams.process(codec::ZstdDecompressor, input::Memory, output::Memory, error::Error) + if codec.dstream.ptr == C_NULL + error("startproc must be called before process") + end dstream = codec.dstream dstream.ibuffer.src = input.ptr dstream.ibuffer.size = input.size diff --git a/src/libzstd.jl b/src/libzstd.jl index 7a557d4..b95ec10 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -123,11 +123,7 @@ mutable struct DStream obuffer::OutBuffer function DStream() - ptr = LibZstd.ZSTD_createDStream() - if ptr == C_NULL - throw(OutOfMemoryError()) - end - return new(ptr, InBuffer(), OutBuffer()) + return new(C_NULL, InBuffer(), OutBuffer()) end end Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_DStream}}, dstream::DStream) = dstream.ptr diff --git a/test/runtests.jl b/test/runtests.jl index a111d9a..73cf82f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -158,4 +158,72 @@ include("utils.jl") include("compress_endOp.jl") include("static_only_tests.jl") + + @testset "reusing a compressor" begin + compressor = ZstdCompressor() + x = rand(UInt8, 1000) + TranscodingStreams.initialize(compressor) + ret1 = transcode(compressor, x) + TranscodingStreams.finalize(compressor) + + # compress again using the same compressor + TranscodingStreams.initialize(compressor) # segfault happens here! + ret2 = transcode(compressor, x) + ret3 = transcode(compressor, x) + TranscodingStreams.finalize(compressor) + + @test transcode(ZstdDecompressor, ret1) == x + @test transcode(ZstdDecompressor, ret2) == x + @test transcode(ZstdDecompressor, ret3) == x + @test ret1 == ret2 + @test ret1 == ret3 + + decompressor = ZstdDecompressor() + TranscodingStreams.initialize(decompressor) + @test transcode(decompressor, ret1) == x + TranscodingStreams.finalize(decompressor) + + TranscodingStreams.initialize(decompressor) + @test transcode(decompressor, ret1) == x + TranscodingStreams.finalize(decompressor) + end + + @testset "use after free doesn't segfault" begin + @testset "$(Codec)" for Codec in (ZstdCompressor, ZstdDecompressor) + codec = Codec() + TranscodingStreams.initialize(codec) + TranscodingStreams.finalize(codec) + data = [0x00,0x01] + GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data)) + try + TranscodingStreams.expectedsize(codec, m) + catch + end + try + TranscodingStreams.minoutsize(codec, m) + catch + end + try + TranscodingStreams.initialize(codec) + catch + end + try + TranscodingStreams.process(codec, m, m, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.startproc(codec, :read, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.process(codec, m, m, TranscodingStreams.Error()) + catch + end + try + TranscodingStreams.finalize(codec) + catch + end + end + end + end end From e906d19db5abf8c0d0d7dc5e7d1947f7d409e920 Mon Sep 17 00:00:00 2001 From: Nathan Zimmerberg <39104088+nhz2@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:42:30 -0400 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Mark Kittisopikul --- src/compression.jl | 5 ++--- src/decompression.jl | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/compression.jl b/src/compression.jl index f7b9ee7..8f2e9b3 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -91,11 +91,10 @@ end function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error) if codec.cstream.ptr == C_NULL - ptr = LibZstd.ZSTD_createCStream() - if ptr == C_NULL + codec.cstream.ptr = LibZstd.ZSTD_createCStream() + if codec.cstream.ptr == C_NULL throw(OutOfMemoryError()) end - codec.cstream.ptr = ptr i_code = initialize!(codec.cstream, codec.level) if iserror(i_code) error[] = ErrorException("zstd initialization error") diff --git a/src/decompression.jl b/src/decompression.jl index 6d680bf..bbe48a6 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -46,11 +46,10 @@ end function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error) if codec.dstream.ptr == C_NULL - ptr = LibZstd.ZSTD_createDStream() - if ptr == C_NULL + codec.dstream.ptr = LibZstd.ZSTD_createDStream() + if codec.dstream.ptr == C_NULL throw(OutOfMemoryError()) end - codec.dstream.ptr = ptr i_code = initialize!(codec.dstream) if iserror(i_code) error[] = ErrorException("zstd initialization error") From 50c6a0f3b3d09dc43a2e245b85ab6397a723d797 Mon Sep 17 00:00:00 2001 From: nhz2 Date: Tue, 17 Sep 2024 14:44:14 -0400 Subject: [PATCH 4/6] add explicit return --- src/compression.jl | 2 +- src/decompression.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compression.jl b/src/compression.jl index 8f2e9b3..fc9361d 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -86,7 +86,7 @@ function TranscodingStreams.finalize(codec::ZstdCompressor) end codec.cstream.ptr = C_NULL end - nothing + return end function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error::Error) diff --git a/src/decompression.jl b/src/decompression.jl index bbe48a6..7ed15cb 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -41,7 +41,7 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor) end codec.dstream.ptr = C_NULL end - nothing + return end function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error) From 23ff1c34c394fcbf5348af1805504cdd5c711b8b Mon Sep 17 00:00:00 2001 From: nhz2 Date: Tue, 17 Sep 2024 15:11:32 -0400 Subject: [PATCH 5/6] Add GC preserve --- test/compress_endOp.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/compress_endOp.jl b/test/compress_endOp.jl index 28fd9f2..f5f120d 100644 --- a/test/compress_endOp.jl +++ b/test/compress_endOp.jl @@ -3,22 +3,22 @@ using Test @testset "compress! endOp = :continue" begin data = rand(1:100, 1024*1024) - cstream = CodecZstd.CStream() - cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() - cstream.ibuffer.src = pointer(data) - cstream.ibuffer.size = sizeof(data) - cstream.ibuffer.pos = 0 - cstream.obuffer.dst = Base.Libc.malloc(sizeof(data)*2) - cstream.obuffer.size = sizeof(data)*2 - cstream.obuffer.pos = 0 - try - GC.@preserve data begin + GC.@preserve data begin + cstream = CodecZstd.CStream() + cstream.ptr = CodecZstd.LibZstd.ZSTD_createCStream() + cstream.ibuffer.src = pointer(data) + cstream.ibuffer.size = sizeof(data) + cstream.ibuffer.pos = 0 + cstream.obuffer.dst = Base.Libc.malloc(sizeof(data)*2) + cstream.obuffer.size = sizeof(data)*2 + cstream.obuffer.pos = 0 + try # default endOp @test CodecZstd.compress!(cstream; endOp=:continue) == 0 @test CodecZstd.find_decompressed_size(cstream.obuffer.dst, cstream.obuffer.pos) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN + finally + Base.Libc.free(cstream.obuffer.dst) end - finally - Base.Libc.free(cstream.obuffer.dst) end end From 60079dd3d5d7da7e32cd75f8361c6584fb36fe24 Mon Sep 17 00:00:00 2001 From: nhz2 Date: Tue, 17 Sep 2024 15:19:11 -0400 Subject: [PATCH 6/6] reset dstream buffers in reset! --- src/libzstd.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/libzstd.jl b/src/libzstd.jl index b95ec10..f9865b6 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -137,6 +137,8 @@ end function reset!(dstream::DStream) # LibZstd.ZSTD_resetDStream is deprecated # https://github.com/facebook/zstd/blob/9d2a45a705e22ad4817b41442949cd0f78597154/lib/zstd.h#L2332-L2339 + reset!(dstream.ibuffer) + reset!(dstream.obuffer) return LibZstd.ZSTD_DCtx_reset(dstream, LibZstd.ZSTD_reset_session_only) end