Skip to content
27 changes: 21 additions & 6 deletions src/compression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,30 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, err::
end
# TODO Allow setting other parameters here.
end
code = reset!(codec.cstream, 0 #=unknown source size=#)
if iserror(code)
# This is unreachable according to zstd.h
err[] = ErrorException("zstd error resetting context.")
return :error
end
reset!(codec.cstream)
return :ok
end

@static if isdefined(TranscodingStreams, :pledgeinsize) # Defined in v0.11.3
function TranscodingStreams.pledgeinsize(codec::ZstdCompressor, insize::Int64, err::Error)::Symbol
if codec.cstream.ptr == C_NULL
error("`startproc` must be called before `pledgeinsize`")
end
srcsize = if signbit(insize)
ZSTD_CONTENTSIZE_UNKNOWN
else
Culonglong(insize)
end
code = LibZstd.ZSTD_CCtx_setPledgedSrcSize(codec.cstream, srcsize)
if iserror(code)
err[] = ErrorException("zstd error setting pledged source size")
:error
else
:ok
end
end
end

function TranscodingStreams.process(codec::ZstdCompressor, input::Memory, output::Memory, err::Error)
if codec.cstream.ptr == C_NULL
error("startproc must be called before process")
Expand Down
18 changes: 6 additions & 12 deletions src/libzstd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,17 @@ Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_CCtx}}, cstream::CStream) = cstream.
Base.unsafe_convert(::Type{Ptr{InBuffer}}, cstream::CStream) = Base.unsafe_convert(Ptr{InBuffer}, cstream.ibuffer)
Base.unsafe_convert(::Type{Ptr{OutBuffer}}, cstream::CStream) = Base.unsafe_convert(Ptr{OutBuffer}, cstream.obuffer)

function reset!(cstream::CStream, srcsize::Integer)
function reset!(cstream::CStream)
# ZSTD_resetCStream is deprecated
# https://github.com/facebook/zstd/blob/9d2a45a705e22ad4817b41442949cd0f78597154/lib/zstd.h#L2253-L2272
res = LibZstd.ZSTD_CCtx_reset(cstream, LibZstd.ZSTD_reset_session_only)
if iserror(res)
return res
end
if srcsize == 0
# From zstd.h:
# Note: ZSTD_resetCStream() interprets pledgedSrcSize == 0 as ZSTD_CONTENTSIZE_UNKNOWN, but
# ZSTD_CCtx_setPledgedSrcSize() does not do the same, so ZSTD_CONTENTSIZE_UNKNOWN must be
# explicitly specified.
srcsize = ZSTD_CONTENTSIZE_UNKNOWN
end
reset!(cstream.ibuffer)
reset!(cstream.obuffer)
return LibZstd.ZSTD_CCtx_setPledgedSrcSize(cstream, srcsize)
if iserror(res)
# According to zstd.h "Resetting session never fails" so this branch should be unreachable.
error("unreachable")
end
return
end

"""
Expand Down
78 changes: 76 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,17 @@ include("utils.jl")
@test CodecZstd.find_decompressed_size(v) == 22

codec = ZstdCompressor
buffer3 = transcode(codec, b"Hello")
buffer4 = transcode(codec, b"World!")
sink = IOBuffer()
s = TranscodingStream(codec(), sink; stop_on_end=true)
write(s, b"Hello")
close(s)
buffer3 = take!(sink)
@test CodecZstd.find_decompressed_size(buffer3) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

transcode with ZstdCompressor now records the decompressed size.

sink = IOBuffer()
s = TranscodingStream(codec(), sink; stop_on_end=true)
write(s, b"Hello")
close(s)
buffer4 = take!(sink)
@test CodecZstd.find_decompressed_size(buffer4) == CodecZstd.ZSTD_CONTENTSIZE_UNKNOWN

write(iob, buffer1)
Expand All @@ -156,6 +164,68 @@ include("utils.jl")
@test CodecZstd.find_decompressed_size(v) == CodecZstd.ZSTD_CONTENTSIZE_ERROR
end

if isdefined(TranscodingStreams, :pledgeinsize)
@testset "pledgeinsize" begin
# when pledgeinsize is available transcode should save the
# decompressed size in a header
for n in [0:30; 1000; 1000000;]
v = transcode(ZstdCompressor, rand(UInt8, n))
@test CodecZstd.find_decompressed_size(v) == n
end

# Test what happens if pledgeinsize promise is broken
d1 = zeros(UInt8, 10000)
d2 = zeros(UInt8, 10000)
GC.@preserve d1 d2 begin
@testset "too many bytes" begin
m1 = TranscodingStreams.Memory(pointer(d1), 1000)
m2 = TranscodingStreams.Memory(pointer(d2), 1000)
codec = ZstdCompressor()
e = TranscodingStreams.Error()
@test TranscodingStreams.startproc(codec, :read, e) === :ok
@test TranscodingStreams.pledgeinsize(codec, Int64(10), e) === :ok
@test TranscodingStreams.process(codec, m1, m2, e) === (0, 0, :error)
@test e[] == ErrorException("zstd compression error: Src size is incorrect")
TranscodingStreams.finalize(codec)
end
@testset "too few bytes" begin
m1 = TranscodingStreams.Memory(pointer(d1), 10)
m2 = TranscodingStreams.Memory(pointer(d2), 1000)
codec = ZstdCompressor()
e = TranscodingStreams.Error()
@test TranscodingStreams.startproc(codec, :read, e) === :ok
@test TranscodingStreams.pledgeinsize(codec, Int64(10000), e) === :ok
@test TranscodingStreams.process(codec, m1, m2, e)[3] === :ok
m1 = TranscodingStreams.Memory(pointer(d1), 0)
@test TranscodingStreams.process(codec, m1, m2, e)[3] === :error
@test e[] == ErrorException("zstd compression error: Src size is incorrect")
TranscodingStreams.finalize(codec)
end
@testset "set pledgeinsize after process" begin
m1 = TranscodingStreams.Memory(pointer(d1), 1000)
m2 = TranscodingStreams.Memory(pointer(d2), 1000)
codec = ZstdCompressor()
e = TranscodingStreams.Error()
@test TranscodingStreams.startproc(codec, :read, e) === :ok
@test TranscodingStreams.process(codec, m1, m2, e)[3] === :ok
@test TranscodingStreams.pledgeinsize(codec, Int64(10000), e) === :error
@test e[] == ErrorException("zstd error setting pledged source size")
TranscodingStreams.finalize(codec)
end
@testset "set unknown pledgeinsize" begin
m1 = TranscodingStreams.Memory(pointer(d1), 1000)
m2 = TranscodingStreams.Memory(pointer(d2), 1000)
codec = ZstdCompressor()
e = TranscodingStreams.Error()
@test TranscodingStreams.startproc(codec, :read, e) === :ok
@test TranscodingStreams.pledgeinsize(codec, Int64(-1), e) === :ok
@test TranscodingStreams.process(codec, m1, m2, e)[3] === :ok
TranscodingStreams.finalize(codec)
end
end
end
end

include("compress_endOp.jl")
include("static_only_tests.jl")

Expand Down Expand Up @@ -195,6 +265,10 @@ include("utils.jl")
TranscodingStreams.finalize(codec)
data = [0x00,0x01]
GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data))
try
TranscodingStreams.pledgeinsize(codec, Int64(10), TranscodingStreams.Error())
catch
end
try
TranscodingStreams.expectedsize(codec, m)
catch
Expand Down
Loading