@@ -8,41 +8,74 @@ static VALUE zstdVersion(VALUE self)
88 return INT2NUM (version );
99}
1010
11- static VALUE compress (int argc , VALUE * argv , VALUE self )
11+ static VALUE rb_compress (int argc , VALUE * argv , VALUE self )
1212{
1313 VALUE input_value ;
1414 VALUE compression_level_value ;
1515 rb_scan_args (argc , argv , "11" , & input_value , & compression_level_value );
16+ int compression_level = convert_compression_level (compression_level_value );
1617
1718 StringValue (input_value );
18- const char * input_data = RSTRING_PTR (input_value );
19+ char * input_data = RSTRING_PTR (input_value );
1920 size_t input_size = RSTRING_LEN (input_value );
21+ size_t max_compressed_size = ZSTD_compressBound (input_size );
2022
21- int compression_level ;
22- if (NIL_P (compression_level_value )) {
23- compression_level = 0 ; // The default. See ZSTD_CLEVEL_DEFAULT in zstd_compress.c
24- } else {
25- compression_level = NUM2INT (compression_level_value );
23+ VALUE output = rb_str_new (NULL , max_compressed_size );
24+ char * output_data = RSTRING_PTR (output );
25+ size_t compressed_size = ZSTD_compress ((void * )output_data , max_compressed_size ,
26+ (void * )input_data , input_size , compression_level );
27+ if (ZSTD_isError (compressed_size )) {
28+ rb_raise (rb_eRuntimeError , "%s: %s" , "compress failed" , ZSTD_getErrorName (compressed_size ));
2629 }
2730
28- // do compress
31+ rb_str_resize (output , compressed_size );
32+ return output ;
33+ }
34+
35+ static VALUE rb_compress_using_dict (int argc , VALUE * argv , VALUE self )
36+ {
37+ VALUE input_value ;
38+ VALUE dict ;
39+ VALUE compression_level_value ;
40+ rb_scan_args (argc , argv , "21" , & input_value , & dict , & compression_level_value );
41+ int compression_level = convert_compression_level (compression_level_value );
42+
43+ StringValue (input_value );
44+ char * input_data = RSTRING_PTR (input_value );
45+ size_t input_size = RSTRING_LEN (input_value );
2946 size_t max_compressed_size = ZSTD_compressBound (input_size );
3047
48+ char * dict_buffer = RSTRING_PTR (dict );
49+ size_t dict_size = RSTRING_LEN (dict );
50+
51+ ZSTD_CDict * const cdict = ZSTD_createCDict (dict_buffer , dict_size , compression_level );
52+ if (cdict == NULL ) {
53+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createCDict failed" );
54+ }
55+ ZSTD_CCtx * const ctx = ZSTD_createCCtx ();
56+ if (ctx == NULL ) {
57+ ZSTD_freeCDict (cdict );
58+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createCCtx failed" );
59+ }
60+
3161 VALUE output = rb_str_new (NULL , max_compressed_size );
3262 char * output_data = RSTRING_PTR (output );
33-
34- size_t compressed_size = ZSTD_compress ((void * )output_data , max_compressed_size ,
35- (const void * )input_data , input_size , compression_level );
63+ size_t const compressed_size = ZSTD_compress_usingCDict (ctx , (void * )output_data , max_compressed_size ,
64+ (void * )input_data , input_size , cdict );
3665
3766 if (ZSTD_isError (compressed_size )) {
67+ ZSTD_freeCDict (cdict );
68+ ZSTD_freeCCtx (ctx );
3869 rb_raise (rb_eRuntimeError , "%s: %s" , "compress failed" , ZSTD_getErrorName (compressed_size ));
39- } else {
40- rb_str_resize (output , compressed_size );
4170 }
4271
72+ rb_str_resize (output , compressed_size );
73+ ZSTD_freeCDict (cdict );
74+ ZSTD_freeCCtx (ctx );
4375 return output ;
4476}
4577
78+
4679static VALUE decompress_buffered (const char * input_data , size_t input_size )
4780{
4881 const size_t outputBufferSize = 4096 ;
@@ -58,7 +91,6 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
5891 rb_raise (rb_eRuntimeError , "%s: %s" , "ZSTD_initDStream failed" , ZSTD_getErrorName (initResult ));
5992 }
6093
61-
6294 VALUE output_string = rb_str_new (NULL , 0 );
6395 ZSTD_outBuffer output = { NULL , 0 , 0 };
6496
@@ -80,23 +112,24 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
80112 return output_string ;
81113}
82114
83- static VALUE decompress (VALUE self , VALUE input )
115+ static VALUE rb_decompress (VALUE self , VALUE input_value )
84116{
85- StringValue (input );
86- const char * input_data = RSTRING_PTR (input );
87- size_t input_size = RSTRING_LEN (input );
88-
89- uint64_t uncompressed_size = ZSTD_getDecompressedSize (input_data , input_size );
117+ StringValue (input_value );
118+ char * input_data = RSTRING_PTR (input_value );
119+ size_t input_size = RSTRING_LEN (input_value );
90120
91- if (uncompressed_size == 0 ) {
121+ unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
122+ if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
123+ rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
124+ }
125+ if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
92126 return decompress_buffered (input_data , input_size );
93127 }
94128
95129 VALUE output = rb_str_new (NULL , uncompressed_size );
96130 char * output_data = RSTRING_PTR (output );
97-
98- size_t decompress_size = ZSTD_decompress ((void * )output_data , uncompressed_size ,
99- (const void * )input_data , input_size );
131+ size_t const decompress_size = ZSTD_decompress ((void * )output_data , uncompressed_size ,
132+ (void * )input_data , input_size );
100133
101134 if (ZSTD_isError (decompress_size )) {
102135 rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
@@ -105,10 +138,61 @@ static VALUE decompress(VALUE self, VALUE input)
105138 return output ;
106139}
107140
141+ static VALUE rb_decompress_using_dict (int argc , VALUE * argv , VALUE self )
142+ {
143+ VALUE input_value ;
144+ VALUE dict ;
145+ rb_scan_args (argc , argv , "20" , & input_value , & dict );
146+
147+ StringValue (input_value );
148+ char * input_data = RSTRING_PTR (input_value );
149+ size_t input_size = RSTRING_LEN (input_value );
150+ unsigned long long const uncompressed_size = ZSTD_getFrameContentSize (input_data , input_size );
151+ if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ) {
152+ rb_raise (rb_eRuntimeError , "%s: %s" , "not compressed by zstd" , ZSTD_getErrorName (uncompressed_size ));
153+ }
154+ if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN ) {
155+ return decompress_buffered (input_data , input_size );
156+ }
157+ VALUE output = rb_str_new (NULL , uncompressed_size );
158+ char * output_data = RSTRING_PTR (output );
159+
160+ char * dict_buffer = RSTRING_PTR (dict );
161+ size_t dict_size = RSTRING_LEN (dict );
162+ ZSTD_DDict * const ddict = ZSTD_createDDict (dict_buffer , dict_size );
163+ if (ddict == NULL ) {
164+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDDict failed" );
165+ }
166+
167+ unsigned const expected_dict_id = ZSTD_getDictID_fromDDict (ddict );
168+ unsigned const actual_dict_id = ZSTD_getDictID_fromFrame (input_data , input_size );
169+ if (expected_dict_id != actual_dict_id ) {
170+ ZSTD_freeDDict (ddict );
171+ rb_raise (rb_eRuntimeError , "%s: %s" , "DictID mismatch" , ZSTD_getErrorName (uncompressed_size ));
172+ }
173+
174+ ZSTD_DCtx * const ctx = ZSTD_createDCtx ();
175+ if (ctx == NULL ) {
176+ ZSTD_freeDDict (ddict );
177+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDCtx failed" );
178+ }
179+ size_t const decompress_size = ZSTD_decompress_usingDDict (ctx , output_data , uncompressed_size , input_data , input_size , ddict );
180+ if (ZSTD_isError (decompress_size )) {
181+ ZSTD_freeDDict (ddict );
182+ ZSTD_freeDCtx (ctx );
183+ rb_raise (rb_eRuntimeError , "%s: %s" , "decompress error" , ZSTD_getErrorName (decompress_size ));
184+ }
185+ ZSTD_freeDDict (ddict );
186+ ZSTD_freeDCtx (ctx );
187+ return output ;
188+ }
189+
108190void
109191zstd_ruby_init (void )
110192{
111193 rb_define_module_function (rb_mZstd , "zstd_version" , zstdVersion , 0 );
112- rb_define_module_function (rb_mZstd , "compress" , compress , -1 );
113- rb_define_module_function (rb_mZstd , "decompress" , decompress , 1 );
194+ rb_define_module_function (rb_mZstd , "compress" , rb_compress , -1 );
195+ rb_define_module_function (rb_mZstd , "compress_using_dict" , rb_compress_using_dict , -1 );
196+ rb_define_module_function (rb_mZstd , "decompress" , rb_decompress , 1 );
197+ rb_define_module_function (rb_mZstd , "decompress_using_dict" , rb_decompress_using_dict , -1 );
114198}
0 commit comments