diff --git a/src/compression.jl b/src/compression.jl index ee518ea..8a7f841 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -180,15 +180,30 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, err:: end # TODO Allow setting other parameters here. end - code = reset!(codec.cstream, 0 #=unknown source size=#) - if iserror(code) - # This is unreachable according to zstd.h - err[] = ErrorException("zstd error resetting context.") - return :error - end + reset!(codec.cstream) return :ok end +@static if isdefined(TranscodingStreams, :pledgeinsize) # Defined in v0.11.3 + function TranscodingStreams.pledgeinsize(codec::ZstdCompressor, insize::Int64, err::Error)::Symbol + if codec.cstream.ptr == C_NULL + error("`startproc` must be called before `pledgeinsize`") + end + srcsize = if signbit(insize) + ZSTD_CONTENTSIZE_UNKNOWN + else + Culonglong(insize) + end + code = LibZstd.ZSTD_CCtx_setPledgedSrcSize(codec.cstream, srcsize) + if iserror(code) + err[] = ErrorException("zstd error setting pledged source size") + :error + else + :ok + end + end +end + function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, err::Error) if codec.cstream.ptr == C_NULL error("startproc must be called before process") diff --git a/src/libzstd.jl b/src/libzstd.jl index c62bf0c..ca980bf 100644 --- a/src/libzstd.jl +++ b/src/libzstd.jl @@ -60,23 +60,17 @@ Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_CCtx}}, cstream::CStream) = cstream. Base.unsafe_convert(::Type{Ptr{InBuffer}}, cstream::CStream) = Base.unsafe_convert(Ptr{InBuffer}, cstream.ibuffer) Base.unsafe_convert(::Type{Ptr{OutBuffer}}, cstream::CStream) = Base.unsafe_convert(Ptr{OutBuffer}, cstream.obuffer) -function reset!(cstream::CStream, srcsize::Integer) +function reset!(cstream::CStream) # ZSTD_resetCStream is deprecated # https://github.com/facebook/zstd/blob/9d2a45a705e22ad4817b41442949cd0f78597154/lib/zstd.h#L2253-L2272 res = LibZstd.ZSTD_CCtx_reset(cstream, LibZstd.ZSTD_reset_session_only) - if iserror(res) - return res - end - if srcsize == 0 - # From zstd.h: - # Note: ZSTD_resetCStream() interprets pledgedSrcSize == 0 as ZSTD_CONTENTSIZE_UNKNOWN, but - # ZSTD_CCtx_setPledgedSrcSize() does not do the same, so ZSTD_CONTENTSIZE_UNKNOWN must be - # explicitly specified. - srcsize = ZSTD_CONTENTSIZE_UNKNOWN - end reset!(cstream.ibuffer) reset!(cstream.obuffer) - return LibZstd.ZSTD_CCtx_setPledgedSrcSize(cstream, srcsize) + if iserror(res) + # According to zstd.h "Resetting session never fails" so this branch should be unreachable. + error("unreachable") + end + return end """ diff --git a/test/runtests.jl b/test/runtests.jl index e403bf2..5d510c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -130,9 +130,17 @@ include("utils.jl") @test CodecZstd.find_decompressed_size(v) == 22 codec = ZstdCompressor - buffer3 = transcode(codec, b"Hello") - buffer4 = transcode(codec, b"World!") + sink = IOBuffer() + s = TranscodingStream(codec(), sink; stop_on_end=true) + write(s, b"Hello") + close(s) + buffer3 = take!(sink) @test CodecZstd.find_decompressed_size(buffer3) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN + sink = IOBuffer() + s = TranscodingStream(codec(), sink; stop_on_end=true) + write(s, b"Hello") + close(s) + buffer4 = take!(sink) @test CodecZstd.find_decompressed_size(buffer4) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN write(iob, buffer1) @@ -156,6 +164,68 @@ include("utils.jl") @test CodecZstd.find_decompressed_size(v) == CodecZstd.ZSTD_CONTENTSIZE_ERROR end + if isdefined(TranscodingStreams, :pledgeinsize) + @testset "pledgeinsize" begin + # when pledgeinsize is available transcode should save the + # decompressed size in a header + for n in [0:30; 1000; 1000000;] + v = transcode(ZstdCompressor, rand(UInt8, n)) + @test CodecZstd.find_decompressed_size(v) == n + end + + # Test what happens if pledgeinsize promise is broken + d1 = zeros(UInt8, 10000) + d2 = zeros(UInt8, 10000) + GC.@preserve d1 d2 begin + @testset "too many bytes" begin + m1 = TranscodingStreams.Memory(pointer(d1), 1000) + m2 = TranscodingStreams.Memory(pointer(d2), 1000) + codec = ZstdCompressor() + e = TranscodingStreams.Error() + @test TranscodingStreams.startproc(codec, :read, e) === :ok + @test TranscodingStreams.pledgeinsize(codec, Int64(10), e) === :ok + @test TranscodingStreams.process(codec, m1, m2, e) === (0, 0, :error) + @test e[] == ErrorException("zstd compression error: Src size is incorrect") + TranscodingStreams.finalize(codec) + end + @testset "too few bytes" begin + m1 = TranscodingStreams.Memory(pointer(d1), 10) + m2 = TranscodingStreams.Memory(pointer(d2), 1000) + codec = ZstdCompressor() + e = TranscodingStreams.Error() + @test TranscodingStreams.startproc(codec, :read, e) === :ok + @test TranscodingStreams.pledgeinsize(codec, Int64(10000), e) === :ok + @test TranscodingStreams.process(codec, m1, m2, e)[3] === :ok + m1 = TranscodingStreams.Memory(pointer(d1), 0) + @test TranscodingStreams.process(codec, m1, m2, e)[3] === :error + @test e[] == ErrorException("zstd compression error: Src size is incorrect") + TranscodingStreams.finalize(codec) + end + @testset "set pledgeinsize after process" begin + m1 = TranscodingStreams.Memory(pointer(d1), 1000) + m2 = TranscodingStreams.Memory(pointer(d2), 1000) + codec = ZstdCompressor() + e = TranscodingStreams.Error() + @test TranscodingStreams.startproc(codec, :read, e) === :ok + @test TranscodingStreams.process(codec, m1, m2, e)[3] === :ok + @test TranscodingStreams.pledgeinsize(codec, Int64(10000), e) === :error + @test e[] == ErrorException("zstd error setting pledged source size") + TranscodingStreams.finalize(codec) + end + @testset "set unknown pledgeinsize" begin + m1 = TranscodingStreams.Memory(pointer(d1), 1000) + m2 = TranscodingStreams.Memory(pointer(d2), 1000) + codec = ZstdCompressor() + e = TranscodingStreams.Error() + @test TranscodingStreams.startproc(codec, :read, e) === :ok + @test TranscodingStreams.pledgeinsize(codec, Int64(-1), e) === :ok + @test TranscodingStreams.process(codec, m1, m2, e)[3] === :ok + TranscodingStreams.finalize(codec) + end + end + end + end + include("compress_endOp.jl") include("static_only_tests.jl") @@ -195,6 +265,10 @@ include("utils.jl") TranscodingStreams.finalize(codec) data = [0x00,0x01] GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data)) + try + TranscodingStreams.pledgeinsize(codec, Int64(10), TranscodingStreams.Error()) + catch + end try TranscodingStreams.expectedsize(codec, m) catch