@@ -207,8 +207,14 @@ def forward(
207207 num_embeddings = 10 ,
208208 feature_names = ["f2" ],
209209 )
210+ config3 = EmbeddingBagConfig (
211+ name = "t3" ,
212+ embedding_dim = 5 ,
213+ num_embeddings = 10 ,
214+ feature_names = ["f3" ],
215+ )
210216 ebc = EmbeddingBagCollection (
211- tables = [config1 , config2 ],
217+ tables = [config1 , config2 , config3 ],
212218 is_weighted = False ,
213219 )
214220
@@ -293,42 +299,60 @@ def test_serialize_deserialize_ebc(self) -> None:
293299 self .assertEqual (deserialized .shape , orginal .shape )
294300 self .assertTrue (torch .allclose (deserialized , orginal ))
295301
296- @unittest .skip ("Adding test for demonstrating VBE KJT flattening issue for now." )
297302 def test_serialize_deserialize_ebc_with_vbe_kjt (self ) -> None :
298303 model = self .generate_model_for_vbe_kjt ()
299- id_list_features = KeyedJaggedTensor (
300- keys = ["f1" , "f2" ],
301- values = torch .tensor ([5 , 6 , 7 , 1 , 2 , 3 , 0 , 1 ]),
302- lengths = torch .tensor ([3 , 3 , 2 ]),
303- stride_per_key_per_rank = [[2 ], [1 ]],
304- inverse_indices = (["f1" , "f2" ], torch .tensor ([[0 , 1 , 0 ], [0 , 0 , 0 ]])),
304+ kjt_1 = KeyedJaggedTensor (
305+ keys = ["f1" , "f2" , "f3" ],
306+ values = torch .tensor ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ]),
307+ lengths = torch .tensor ([1 , 2 , 3 , 2 , 1 , 1 ]),
308+ stride_per_key_per_rank = torch .tensor ([[3 ], [2 ], [1 ]]),
309+ inverse_indices = (
310+ ["f1" , "f2" , "f3" ],
311+ torch .tensor ([[0 , 1 , 2 ], [0 , 1 , 0 ], [0 , 0 , 0 ]]),
312+ ),
313+ )
314+ kjt_2 = KeyedJaggedTensor (
315+ keys = ["f1" , "f2" , "f3" ],
316+ values = torch .tensor ([1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 0 ]),
317+ lengths = torch .tensor ([1 , 2 , 3 , 2 , 1 , 1 ]),
318+ stride_per_key_per_rank = torch .tensor ([[1 ], [2 ], [3 ]]),
319+ inverse_indices = (
320+ ["f1" , "f2" , "f3" ],
321+ torch .tensor ([[0 , 0 , 0 ], [0 , 1 , 0 ], [0 , 1 , 2 ]]),
322+ ),
305323 )
306324
307- eager_out = model (id_list_features )
325+ eager_out = model (kjt_1 )
326+ eager_out_2 = model (kjt_2 )
308327
309328 # Serialize EBC
310329 model , sparse_fqns = encapsulate_ir_modules (model , JsonSerializer )
311330 ep = torch .export .export (
312331 model ,
313- (id_list_features ,),
332+ (kjt_1 ,),
314333 {},
315334 strict = False ,
316335 # Allows KJT to not be unflattened and run a forward on unflattened EP
317336 preserve_module_call_signature = (tuple (sparse_fqns )),
318337 )
319338
320339 # Run forward on ExportedProgram
321- ep_output = ep .module ()(id_list_features )
340+ ep_output = ep .module ()(kjt_1 )
341+ ep_output_2 = ep .module ()(kjt_2 )
322342
343+ self .assertEqual (len (ep_output ), len (kjt_1 .keys ()))
344+ self .assertEqual (len (ep_output_2 ), len (kjt_2 .keys ()))
323345 for i , tensor in enumerate (ep_output ):
324- self .assertEqual (eager_out [i ].shape , tensor .shape )
346+ self .assertEqual (eager_out [i ].shape [1 ], tensor .shape [1 ])
347+ for i , tensor in enumerate (ep_output_2 ):
348+ self .assertEqual (eager_out_2 [i ].shape [1 ], tensor .shape [1 ])
325349
326350 # Deserialize EBC
327351 unflatten_ep = torch .export .unflatten (ep )
328352 deserialized_model = decapsulate_ir_modules (unflatten_ep , JsonSerializer )
329353
330354 # check EBC config
331- for i in range (5 ):
355+ for i in range (1 ):
332356 ebc_name = f"ebc{ i + 1 } "
333357 self .assertIsInstance (
334358 getattr (deserialized_model , ebc_name ), EmbeddingBagCollection
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
343367 self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
344368 self .assertEqual (deserialized .feature_names , orginal .feature_names )
345369
346- # check FPEBC config
347- for i in range (2 ):
348- fpebc_name = f"fpebc{ i + 1 } "
349- assert isinstance (
350- getattr (deserialized_model , fpebc_name ),
351- FeatureProcessedEmbeddingBagCollection ,
352- )
353-
354- for deserialized , orginal in zip (
355- getattr (
356- deserialized_model , fpebc_name
357- )._embedding_bag_collection .embedding_bag_configs (),
358- getattr (
359- model , fpebc_name
360- )._embedding_bag_collection .embedding_bag_configs (),
361- ):
362- self .assertEqual (deserialized .name , orginal .name )
363- self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
364- self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
365- self .assertEqual (deserialized .feature_names , orginal .feature_names )
366-
367370 # Run forward on deserialized model and compare the output
368371 deserialized_model .load_state_dict (model .state_dict ())
369- deserialized_out = deserialized_model (id_list_features )
372+ deserialized_out = deserialized_model (kjt_1 )
370373
371374 self .assertEqual (len (deserialized_out ), len (eager_out ))
372375 for deserialized , orginal in zip (deserialized_out , eager_out ):
373376 self .assertEqual (deserialized .shape , orginal .shape )
374377 self .assertTrue (torch .allclose (deserialized , orginal ))
375378
379+ deserialized_out_2 = deserialized_model (kjt_2 )
380+
381+ self .assertEqual (len (deserialized_out_2 ), len (eager_out_2 ))
382+ for deserialized , orginal in zip (deserialized_out_2 , eager_out_2 ):
383+ self .assertEqual (deserialized .shape , orginal .shape )
384+ self .assertTrue (torch .allclose (deserialized , orginal ))
385+
376386 def test_dynamic_shape_ebc_disabled_in_oss_compatibility (self ) -> None :
377387 model = self .generate_model ()
378388 feature1 = KeyedJaggedTensor .from_offsets_sync (