@@ -26,6 +26,7 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
2626 char * input_data = RSTRING_PTR (input_value );
2727 size_t input_size = RSTRING_LEN (input_value );
2828 ZSTD_inBuffer input = { input_data , input_size , 0 };
29+ // ZSTD_compressBound causes SEGV under multi-thread
2930 size_t max_compressed_size = ZSTD_compressBound (input_size );
3031 VALUE buf = rb_str_new (NULL , max_compressed_size );
3132 char * output_data = RSTRING_PTR (buf );
@@ -87,19 +88,8 @@ static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self)
8788}
8889
8990
90- static VALUE decompress_buffered (const char * input_data , size_t input_size )
91+ static VALUE decompress_buffered (ZSTD_DCtx * dctx , const char * input_data , size_t input_size )
9192{
92- ZSTD_DStream * const dstream = ZSTD_createDStream ();
93- if (dstream == NULL ) {
94- rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDStream failed" );
95- }
96-
97- size_t initResult = ZSTD_initDStream (dstream );
98- if (ZSTD_isError (initResult )) {
99- ZSTD_freeDStream (dstream );
100- rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_initDStream failed" , ZSTD_getErrorName (initResult ));
101- }
102-
10393 VALUE output_string = rb_str_new (NULL , 0 );
10494 ZSTD_outBuffer output = { NULL , 0 , 0 };
10595
@@ -109,15 +99,14 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
10999 rb_str_resize (output_string , output .size );
110100 output .dst = RSTRING_PTR (output_string );
111101
112- size_t readHint = ZSTD_decompressStream (dstream , & output , & input );
113- if (ZSTD_isError (readHint )) {
114- ZSTD_freeDStream ( dstream );
115- rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_decompressStream failed" , ZSTD_getErrorName (readHint ));
102+ size_t ret = ZSTD_decompressStream (dctx , & output , & input );
103+ if (ZSTD_isError (ret )) {
104+ ZSTD_freeDCtx ( dctx );
105+ rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_decompressStream failed" , ZSTD_getErrorName (ret ));
116106 }
117107 }
118-
119- ZSTD_freeDStream (dstream );
120108 rb_str_resize (output_string , output .pos );
109+ ZSTD_freeDCtx (dctx );
121110 return output_string ;
122111}
123112
@@ -129,6 +118,11 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
129118 StringValue (input_value );
130119 char * input_data = RSTRING_PTR (input_value );
131120 size_t input_size = RSTRING_LEN (input_value );
121+ ZSTD_DCtx * const dctx = ZSTD_createDCtx ();
122+ if (dctx == NULL ) {
123+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
124+ }
125+ set_decompress_params (dctx , kwargs );
132126
133127 unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
134128 if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
@@ -137,15 +131,9 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
137131 // ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
138132 // Therefore, we will not standardize on ZSTD_decompressStream
139133 if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
140- return decompress_buffered (input_data , input_size );
134+ return decompress_buffered (dctx , input_data , input_size );
141135 }
142136
143- ZSTD_DCtx * const dctx = ZSTD_createDCtx ();
144- if (dctx == NULL ) {
145- rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
146- }
147- set_decompress_params (dctx , kwargs );
148-
149137 VALUE output = rb_str_new (NULL , uncompressed_size );
150138 char * output_data = RSTRING_PTR (output );
151139
@@ -167,35 +155,38 @@ static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self)
167155 StringValue (input_value );
168156 char * input_data = RSTRING_PTR (input_value );
169157 size_t input_size = RSTRING_LEN (input_value );
170- unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
171- if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
172- rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
173- }
174- if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
175- return decompress_buffered (input_data , input_size );
176- }
177- VALUE output = rb_str_new (NULL , uncompressed_size );
178- char * output_data = RSTRING_PTR (output );
179158
180159 char * dict_buffer = RSTRING_PTR (dict );
181160 size_t dict_size = RSTRING_LEN (dict );
182161 ZSTD_DDict * const ddict = ZSTD_createDDict (dict_buffer , dict_size );
183162 if (ddict == NULL ) {
184163 rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDDict failed" );
185164 }
186-
187165 unsigned const expected_dict_id = ZSTD_getDictID_fromDDict (ddict );
188166 unsigned const actual_dict_id = ZSTD_getDictID_fromFrame (input_data , input_size );
189167 if (expected_dict_id != actual_dict_id ) {
190168 ZSTD_freeDDict (ddict );
191- rb_raise (rb_eRuntimeError , "%s: %s" , " DictID mismatch", ZSTD_getErrorName ( uncompressed_size ) );
169+ rb_raise (rb_eRuntimeError , "DictID mismatch" );
192170 }
193171
194172 ZSTD_DCtx * const ctx = ZSTD_createDCtx ();
195173 if (ctx == NULL ) {
196174 ZSTD_freeDDict (ddict );
197175 rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
198176 }
177+
178+ unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
179+ if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
180+ ZSTD_freeDDict (ddict );
181+ ZSTD_freeDCtx (ctx );
182+ rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
183+ }
184+ if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
185+ return decompress_buffered (ctx , input_data , input_size );
186+ }
187+
188+ VALUE output = rb_str_new (NULL , uncompressed_size );
189+ char * output_data = RSTRING_PTR (output );
199190 size_t const decompress_size = ZSTD_decompress_usingDDict (ctx , output_data , uncompressed_size , input_data , input_size , ddict );
200191 if (ZSTD_isError (decompress_size )) {
201192 ZSTD_freeDDict (ddict );
0 commit comments