@@ -195,6 +195,90 @@ static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self)
195195 return output ;
196196}
197197
198+ static void free_cdict (void * dict )
199+ {
200+ ZSTD_freeCDict (dict );
201+ }
202+
203+ static size_t sizeof_cdict (const void * dict )
204+ {
205+ return ZSTD_sizeof_CDict (dict );
206+ }
207+
208+ static void free_ddict (void * dict )
209+ {
210+ ZSTD_freeDDict (dict );
211+ }
212+
213+ static size_t sizeof_ddict (const void * dict )
214+ {
215+ return ZSTD_sizeof_DDict (dict );
216+ }
217+
218+ static const rb_data_type_t cdict_type = {
219+ "Zstd::CDict" ,
220+ {0 , free_cdict , sizeof_cdict ,},
221+ 0 , 0 , RUBY_TYPED_FREE_IMMEDIATELY
222+ };
223+
224+ static const rb_data_type_t ddict_type = {
225+ "Zstd::DDict" ,
226+ {0 , free_ddict , sizeof_ddict ,},
227+ 0 , 0 , RUBY_TYPED_FREE_IMMEDIATELY
228+ };
229+
230+ static VALUE rb_cdict_alloc (VALUE self )
231+ {
232+ ZSTD_CDict * cdict = NULL ;
233+ return TypedData_Wrap_Struct (self , & cdict_type , cdict );
234+ }
235+
236+ static VALUE rb_cdict_initialize (int argc , VALUE * argv , VALUE self )
237+ {
238+ VALUE dict ;
239+ VALUE compression_level_value ;
240+ rb_scan_args (argc , argv , "11" , & dict , & compression_level_value );
241+ int compression_level = convert_compression_level (compression_level_value );
242+
243+ StringValue (dict );
244+ char * dict_buffer = RSTRING_PTR (dict );
245+ size_t dict_size = RSTRING_LEN (dict );
246+
247+ ZSTD_CDict * const cdict = ZSTD_createCDict (dict_buffer , dict_size , compression_level );
248+ if (cdict == NULL ) {
249+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createCDict failed" );
250+ }
251+
252+ DATA_PTR (self ) = cdict ;
253+ return self ;
254+ }
255+
256+ static VALUE rb_ddict_alloc (VALUE self )
257+ {
258+ ZSTD_CDict * ddict = NULL ;
259+ return TypedData_Wrap_Struct (self , & ddict_type , ddict );
260+ }
261+
262+ static VALUE rb_ddict_initialize (VALUE self , VALUE dict )
263+ {
264+ StringValue (dict );
265+ char * dict_buffer = RSTRING_PTR (dict );
266+ size_t dict_size = RSTRING_LEN (dict );
267+
268+ ZSTD_DDict * const ddict = ZSTD_createDDict (dict_buffer , dict_size );
269+ if (ddict == NULL ) {
270+ rb_raise (rb_eRuntimeError , "%s" , "ZSTD_createDDict failed" );
271+ }
272+
273+ DATA_PTR (self ) = ddict ;
274+ return self ;
275+ }
276+
277+ static VALUE rb_prohibit_copy (VALUE , VALUE )
278+ {
279+ rb_raise (rb_eRuntimeError , "CDict cannot be duplicated" );
280+ }
281+
198282void
199283zstd_ruby_init (void )
200284{
@@ -203,4 +287,12 @@ zstd_ruby_init(void)
203287 rb_define_module_function (rb_mZstd , "compress_using_dict" , rb_compress_using_dict , -1 );
204288 rb_define_module_function (rb_mZstd , "decompress" , rb_decompress , -1 );
205289 rb_define_module_function (rb_mZstd , "decompress_using_dict" , rb_decompress_using_dict , -1 );
290+
291+ rb_define_alloc_func (rb_cCDict , rb_cdict_alloc );
292+ rb_define_private_method (rb_cCDict , "initialize" , rb_cdict_initialize , -1 );
293+ rb_define_method (rb_cCDict , "initialize_copy" , rb_prohibit_copy , 1 );
294+
295+ rb_define_alloc_func (rb_cDDict , rb_ddict_alloc );
296+ rb_define_private_method (rb_cDDict , "initialize" , rb_ddict_initialize , 1 );
297+ rb_define_method (rb_cDDict , "initialize_copy" , rb_prohibit_copy , 1 );
206298}
0 commit comments