Skip to content

Commit 5720e12

Browse files
committed
Fix multiple zlib decoding bugs
1 parent c30f2a1 commit 5720e12

File tree

2 files changed

+55
-10
lines changed

2 files changed

+55
-10
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_zlib.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ def _disable_native_zlib(self, flag):
1414
return None
1515
__graalpython__ = GP()
1616

17+
import binascii
1718
import os
19+
import random
20+
import sys
1821
import unittest
1922
import zlib
20-
import binascii
21-
import sys
2223

2324
pintNumber = 98765432109876543210
2425
longNumber = 9876543210
@@ -271,3 +272,35 @@ def test_GR65704():
271272
__graalpython__._disable_native_zlib(False)
272273

273274
assert decompressed == contents
275+
276+
def test_large_chunk():
    """Round-trip 5000 random bytes through a gzip stream, decompressing in one call.

    Native zlib is disabled for the duration of the compress/decompress calls
    (presumably so the Java fallback implementation is exercised - confirm
    against `_disable_native_zlib`).
    """
    contents = random.randbytes(5000)
    # Per the zlib module docs, wbits = 16 + 15 selects a gzip header/trailer.
    wbits = 31

    __graalpython__._disable_native_zlib(True)
    try:
        compressed = zlib.compress(contents, wbits=wbits)
        decompressor = zlib.decompressobj(wbits=wbits)

        decompressed = decompressor.decompress(compressed)
    finally:
        # Always restore the flag: if the calls above raise, native zlib must
        # not stay disabled for the tests that run afterwards.
        __graalpython__._disable_native_zlib(False)

    assert decompressed == contents
290+
291+
def test_various_chunks():
    """Round-trip 5000 random bytes through a gzip stream fed in three pieces.

    The compressed data is passed to the decompressor as the slices [:10],
    [10:200] and [200:], and the concatenated output must equal the input.
    Native zlib is disabled for the duration of the compress/decompress calls
    (presumably so the Java fallback implementation is exercised - confirm
    against `_disable_native_zlib`).
    """
    contents = random.randbytes(5000)
    # Per the zlib module docs, wbits = 16 + 15 selects a gzip header/trailer.
    wbits = 31

    __graalpython__._disable_native_zlib(True)
    try:
        compressed = zlib.compress(contents, wbits=wbits)
        decompressor = zlib.decompressobj(wbits=wbits)

        decompressed = decompressor.decompress(compressed[:10])
        decompressed += decompressor.decompress(compressed[10:200])
        decompressed += decompressor.decompress(compressed[200:])
    finally:
        # Always restore the flag: if the calls above raise, native zlib must
        # not stay disabled for the tests that run afterwards.
        __graalpython__._disable_native_zlib(False)

    assert decompressed == contents

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/zlib/JavaDecompress.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
import java.io.ByteArrayInputStream;
4949
import java.io.ByteArrayOutputStream;
50+
import java.io.EOFException;
5051
import java.io.IOException;
5152
import java.io.InputStream;
5253
import java.util.zip.DataFormatException;
@@ -112,7 +113,7 @@ public Inflater getInflater() {
112113
return inf;
113114
}
114115

115-
public void setInput() throws IOException {
116+
/**
 * Delegates to {@code fill()}.
 *
 * NOTE(review): {@code fill()} is presumably inherited from
 * {@code java.util.zip.InflaterInputStream} and reads more compressed bytes
 * from the underlying stream into the inflater - confirm against the superclass.
 *
 * @throws IOException propagated from {@code fill()}
 */
public void fillInput() throws IOException {
116117
fill();
117118
}
118119
}
@@ -162,7 +163,7 @@ private static boolean isGZIPStreamReady(DecompressStream stream, byte[] data, i
162163
// GZIPInputStream will read the header during initialization
163164
stream.stream = new GZIPDecompressStream(stream.in);
164165
stream.inflater = stream.stream.getInflater();
165-
stream.stream.setInput();
166+
stream.stream.fillInput();
166167
return true;
167168
}
168169
} catch (ZipException ze) {
@@ -178,7 +179,7 @@ private static boolean isGZIPStreamFinishing(DecompressStream stream, byte[] dat
178179
stream.in.append(data, 0, length);
179180
try {
180181
if (stream.in.length() >= HEADER_TRAILER_SIZE) {
181-
stream.stream.setInput();
182+
stream.stream.fillInput();
182183
// this should trigger reading trailer
183184
stream.stream.read();
184185
stream.stream = null;
@@ -246,19 +247,30 @@ private byte[] createByteArray(byte[] bytes, int length, int maxLength, int bufS
246247
int maxLen = maxLength <= 0 ? Integer.MAX_VALUE : maxLength;
247248
byte[] result = new byte[Math.min(maxLen, bufSize)];
248249

249-
int bytesWritten = result.length;
250250
ByteArrayOutputStream baos = new ByteArrayOutputStream();
251251
boolean zdictIsSet = false;
252-
while (baos.size() < maxLen && bytesWritten == result.length) {
252+
while (baos.size() < maxLen && !stream.inflater.finished()) {
253+
if (stream.inflater.needsInput()) {
254+
if (stream.stream == null) {
255+
break;
256+
}
257+
try {
258+
stream.stream.fillInput();
259+
} catch (EOFException e) {
260+
break;
261+
} catch (IOException e) {
262+
throw CompilerDirectives.shouldNotReachHere(e);
263+
}
264+
}
265+
int bytesWritten;
253266
try {
254267
int len = Math.min(maxLen - baos.size(), result.length);
255268
bytesWritten = stream.inflater.inflate(result, 0, len);
256269
if (bytesWritten == 0 && !zdictIsSet && stream.inflater.needsDictionary()) {
257270
if (getZdict().length > 0) {
258271
setDictionary();
259272
zdictIsSet = true;
260-
// we inflate again with a dictionary
261-
bytesWritten = stream.inflater.inflate(result, 0, len);
273+
continue;
262274
} else {
263275
throw PRaiseNode.raiseStatic(nodeForRaise, ZLibError, WHILE_SETTING_ZDICT);
264276
}
@@ -320,7 +332,7 @@ protected static byte[] decompress(byte[] bytes, int length, int wbits, int bufs
320332
private void saveUnconsumedInput(byte[] data, int length,
321333
byte[] unusedDataBytes, int unconsumedTailLen, Node inliningTarget) {
322334
int unusedLen = getRemaining();
323-
byte[] tail = PythonUtils.arrayCopyOfRange(data, length - unusedLen, length);
335+
byte[] tail = PythonUtils.arrayCopyOfRange(data, Math.max(0, length - unusedLen), length);
324336
PythonLanguage language = PythonLanguage.get(inliningTarget);
325337
if (isEof()) {
326338
if (unconsumedTailLen > 0) {

0 commit comments

Comments
 (0)