Skip to content

Commit 0bf32c0

Browse files
committed
Add tests
1 parent 4f182f8 commit 0bf32c0

File tree

4 files changed

+86
-19
lines changed

4 files changed

+86
-19
lines changed

src/compression.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function TranscodingStreams.startproc(codec::ZstdCompressor, mode::Symbol, error
9898
codec.cstream.ptr = ptr
9999
i_code = initialize!(codec.cstream, codec.level)
100100
if iserror(i_code)
101-
error[] = ErrorException("zstd error")
101+
error[] = ErrorException("zstd initialization error")
102102
return :error
103103
end
104104
end

src/decompression.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@ end
3333
# Methods
3434
# -------
3535

36-
function TranscodingStreams.initialize(codec::ZstdDecompressor)
37-
code = initialize!(codec.dstream)
38-
if iserror(code)
39-
zstderror(codec.dstream, code)
40-
end
41-
reset!(codec.dstream.ibuffer)
42-
reset!(codec.dstream.obuffer)
43-
return
44-
end
45-
4636
function TranscodingStreams.finalize(codec::ZstdDecompressor)
4737
if codec.dstream.ptr != C_NULL
4838
code = free!(codec.dstream)
@@ -51,12 +41,22 @@ function TranscodingStreams.finalize(codec::ZstdDecompressor)
5141
end
5242
codec.dstream.ptr = C_NULL
5343
end
54-
reset!(codec.dstream.ibuffer)
55-
reset!(codec.dstream.obuffer)
56-
return
44+
nothing
5745
end
5846

5947
function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, error::Error)
48+
if codec.dstream.ptr == C_NULL
49+
ptr = LibZstd.ZSTD_createDStream()
50+
if ptr == C_NULL
51+
throw(OutOfMemoryError())
52+
end
53+
codec.dstream.ptr = ptr
54+
i_code = initialize!(codec.dstream)
55+
if iserror(i_code)
56+
error[] = ErrorException("zstd initialization error")
57+
return :error
58+
end
59+
end
6060
code = reset!(codec.dstream)
6161
if iserror(code)
6262
error[] = ErrorException("zstd error")
@@ -66,6 +66,9 @@ function TranscodingStreams.startproc(codec::ZstdDecompressor, mode::Symbol, err
6666
end
6767

6868
function TranscodingStreams.process(codec::ZstdDecompressor, input::Memory, output::Memory, error::Error)
69+
if codec.dstream.ptr == C_NULL
70+
error("startproc must be called before process")
71+
end
6972
dstream = codec.dstream
7073
dstream.ibuffer.src = input.ptr
7174
dstream.ibuffer.size = input.size

src/libzstd.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,7 @@ mutable struct DStream
123123
obuffer::OutBuffer
124124

125125
function DStream()
126-
ptr = LibZstd.ZSTD_createDStream()
127-
if ptr == C_NULL
128-
throw(OutOfMemoryError())
129-
end
130-
return new(ptr, InBuffer(), OutBuffer())
126+
return new(C_NULL, InBuffer(), OutBuffer())
131127
end
132128
end
133129
Base.unsafe_convert(::Type{Ptr{LibZstd.ZSTD_DStream}}, dstream::DStream) = dstream.ptr

test/runtests.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,72 @@ include("utils.jl")
158158

159159
include("compress_endOp.jl")
160160
include("static_only_tests.jl")
161+
162+
@testset "reusing a compressor" begin
163+
compressor = ZstdCompressor()
164+
x = rand(UInt8, 1000)
165+
TranscodingStreams.initialize(compressor)
166+
ret1 = transcode(compressor, x)
167+
TranscodingStreams.finalize(compressor)
168+
169+
# compress again using the same compressor
170+
TranscodingStreams.initialize(compressor) # segfault happens here!
171+
ret2 = transcode(compressor, x)
172+
ret3 = transcode(compressor, x)
173+
TranscodingStreams.finalize(compressor)
174+
175+
@test transcode(ZstdDecompressor, ret1) == x
176+
@test transcode(ZstdDecompressor, ret2) == x
177+
@test transcode(ZstdDecompressor, ret3) == x
178+
@test ret1 == ret2
179+
@test ret1 == ret3
180+
181+
decompressor = ZstdDecompressor()
182+
TranscodingStreams.initialize(decompressor)
183+
@test transcode(decompressor, ret1) == x
184+
TranscodingStreams.finalize(decompressor)
185+
186+
TranscodingStreams.initialize(decompressor)
187+
@test transcode(decompressor, ret1) == x
188+
TranscodingStreams.finalize(decompressor)
189+
end
190+
191+
@testset "use after free doesn't segfault" begin
192+
@testset "$(Codec)" for Codec in (ZstdCompressor, ZstdDecompressor)
193+
codec = Codec()
194+
TranscodingStreams.initialize(codec)
195+
TranscodingStreams.finalize(codec)
196+
data = [0x00,0x01]
197+
GC.@preserve data let m = TranscodingStreams.Memory(pointer(data), length(data))
198+
try
199+
TranscodingStreams.expectedsize(codec, m)
200+
catch
201+
end
202+
try
203+
TranscodingStreams.minoutsize(codec, m)
204+
catch
205+
end
206+
try
207+
TranscodingStreams.initialize(codec)
208+
catch
209+
end
210+
try
211+
TranscodingStreams.process(codec, m, m, TranscodingStreams.Error())
212+
catch
213+
end
214+
try
215+
TranscodingStreams.startproc(codec, :read, TranscodingStreams.Error())
216+
catch
217+
end
218+
try
219+
TranscodingStreams.process(codec, m, m, TranscodingStreams.Error())
220+
catch
221+
end
222+
try
223+
TranscodingStreams.finalize(codec)
224+
catch
225+
end
226+
end
227+
end
228+
end
161229
end

0 commit comments

Comments
 (0)