@@ -2605,7 +2605,71 @@ def init_parameters(self) -> None:
                 weight_init_max,
             )
 
-    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+    def _get_hash_zch_identities(
+        self, features: KeyedJaggedTensor
+    ) -> Optional[torch.Tensor]:
+        if self._raw_id_tracker_wrapper is None or not isinstance(
+            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
+        ):
+            return None
+
+        raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
+        assert (
+            raw_id_tracker_wrapper is not None
+        ), "self._raw_id_tracker_wrapper should not be None"
+        assert hasattr(
+            self.emb_module, "res_params"
+        ), "res_params should exist when raw_id_tracker is enabled"
+        res_params: RESParams = self.emb_module.res_params  # pyre-ignore[9]
+        table_names = res_params.table_names
+
+        # TODO: get_indexed_lookups() may return multiple IndexedLookup objects
+        # across multiple training iterations. Current logic appends raw_ids from
+        # all batches sequentially. This may cause misalignment with
+        # features.values(), which only contains the current batch.
+        raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
+            table_names, self.emb_module.uuid
+        )
+
+        # Build hash_zch_identities by concatenating raw IDs from tracked tables.
+        # Output maintains 1-to-1 alignment with features.values().
+        # Iterate through table_names explicitly (not raw_ids_dict.values()) to
+        # ensure correct ordering, since there is no guarantee on dict ordering.
+        #
+        # E.g. if features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...],
+        # where table1 has [feature1, feature2] and table2 has [feature3, feature4],
+        # then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...].
+        #
+        # TODO: Handle tables without identity tracking. Currently, only tables with
+        # raw_ids are included. If some tables lack identities while others have them,
+        # padding with -1 may be needed to maintain alignment.
+        all_raw_ids = []
+        for table_name in table_names:
+            if table_name in raw_ids_dict:
+                raw_ids_list = raw_ids_dict[table_name]
+                for raw_ids in raw_ids_list:
+                    all_raw_ids.append(raw_ids)
+
+        if not all_raw_ids:
+            return None
+
+        hash_zch_identities = torch.cat(all_raw_ids)
+        assert hash_zch_identities.size(0) == features.values().numel(), (
+            f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
+            f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
+        )
+
+        return hash_zch_identities
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+    ) -> torch.Tensor:
+        forward_args: Dict[str, Any] = {}
+        hash_zch_identities = self._get_hash_zch_identities(features)
+        if hash_zch_identities is not None:
+            forward_args["hash_zch_identities"] = hash_zch_identities
+
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
@@ -2617,17 +2681,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
+            forward_args["batch_size_per_feature_per_rank"] = (
+                features.stride_per_key_per_rank()
+            )
+
+        if len(forward_args) == 0:
             return self.emb_module(
                 indices=features.values().long(),
                 offsets=features.offsets().long(),
                 per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
             )
         else:
             return self.emb_module(
                 indices=features.values().long(),
                 offsets=features.offsets().long(),
                 per_sample_weights=weights,
+                **forward_args,
             )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
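For reference, the 1-to-1 alignment invariant that the new _get_hash_zch_identities helper asserts can be illustrated outside torchrec with plain tensors. This is a minimal sketch under assumed inputs: the helper name build_hash_zch_identities, the table names, and the raw ID values are hypothetical, and stand in for the tracker's per-table IndexedLookup results; only the concatenation-in-table_names-order and size check mirror the diff above.

import torch
from typing import Dict, List, Optional

def build_hash_zch_identities(
    table_names: List[str],
    raw_ids_dict: Dict[str, List[torch.Tensor]],
    num_feature_values: int,
) -> Optional[torch.Tensor]:
    # Concatenate raw IDs in table_names order (not dict iteration order) so the
    # result stays 1-to-1 aligned with the flattened feature values.
    all_raw_ids = [
        raw_ids
        for table_name in table_names
        if table_name in raw_ids_dict
        for raw_ids in raw_ids_dict[table_name]
    ]
    if not all_raw_ids:
        return None
    identities = torch.cat(all_raw_ids)
    # Same check as the assert in the diff: one identity row per feature value.
    assert identities.size(0) == num_feature_values
    return identities

# Hypothetical example: table1 serves feature1/feature2, table2 serves feature3/feature4.
raw_ids_dict = {
    "table1": [torch.tensor([101, 102])],  # raw IDs for f1_val1, f1_val2
    "table2": [torch.tensor([201, 202])],  # raw IDs for f2_val1, f2_val2
}
values = torch.tensor([7, 8, 9, 10])       # stand-in for features.values()

identities = build_hash_zch_identities(["table1", "table2"], raw_ids_dict, values.numel())
print(identities)  # tensor([101, 102, 201, 202]) -- aligned with values

If any table had no tracked raw IDs while others did, the concatenated tensor would be shorter than values and the assert would fire, which is the gap the second TODO in the diff calls out (padding with -1 is one possible fix).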