From 5dc811d32c40db7b527aea05c8f419c71444088a Mon Sep 17 00:00:00 2001
From: Joey Yang
Date: Mon, 10 Nov 2025 23:08:35 -0800
Subject: [PATCH] Look up and forward raw_ids from tracker to TBE from
 embedding module (#3527)

Summary:
We look up the identities corresponding to the slot indices in `raw_id_tracker` and forward them to TBE. The identities are stored in `raw_id_tracker` during the `mc_module` lookup.

`raw_ids` are retrieved per table from `raw_id_tracker` via the `get_indexed_lookups()` API, and we concatenate them to maintain 1-to-1 alignment with `features.values()`.

Reviewed By: chouxi

Differential Revision: D86242001
---
 .../distributed/batched_embedding_kernel.py  | 73 ++++++++++++++++++-
 1 file changed, 71 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index ddef4be14..b8f3bf17b 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -2605,7 +2605,71 @@ def init_parameters(self) -> None:
                 weight_init_max,
             )
 
-    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+    def _get_hash_zch_identities(
+        self, features: KeyedJaggedTensor
+    ) -> Optional[torch.Tensor]:
+        if self._raw_id_tracker_wrapper is None or not isinstance(
+            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
+        ):
+            return None
+
+        raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
+        assert (
+            raw_id_tracker_wrapper is not None
+        ), "self._raw_id_tracker_wrapper should not be None"
+        assert hasattr(
+            self.emb_module, "res_params"
+        ), "res_params should exist when raw_id_tracker is enabled"
+        res_params: RESParams = self.emb_module.res_params  # pyre-ignore[9]
+        table_names = res_params.table_names
+
+        # TODO: get_indexed_lookups() may return multiple IndexedLookup objects
+        # across multiple training iterations. The current logic appends raw_ids
+        # from all batches sequentially, which may cause misalignment with
+        # features.values(), which only contains the current batch.
+        raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
+            table_names, self.emb_module.uuid
+        )
+
+        # Build hash_zch_identities by concatenating raw IDs from tracked tables.
+        # The output maintains 1-to-1 alignment with features.values().
+        # Iterate through table_names explicitly (not raw_ids_dict.values()) to
+        # ensure correct ordering, since there is no guarantee on dict ordering.
+        #
+        # E.g. if features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
+        # where table1 has [feature1, feature2] and table2 has [feature3, feature4],
+        # then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
+        #
+        # TODO: Handle tables without identity tracking. Currently, only tables with
+        # raw_ids are included. If some tables lack identities while others have them,
+        # padding with -1 may be needed to maintain alignment.
+        all_raw_ids = []
+        for table_name in table_names:
+            if table_name in raw_ids_dict:
+                raw_ids_list = raw_ids_dict[table_name]
+                for raw_ids in raw_ids_list:
+                    all_raw_ids.append(raw_ids)
+
+        if not all_raw_ids:
+            return None
+
+        hash_zch_identities = torch.cat(all_raw_ids)
+        assert hash_zch_identities.size(0) == features.values().numel(), (
+            f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
+            f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
+        )
+
+        return hash_zch_identities
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+    ) -> torch.Tensor:
+        forward_args: Dict[str, Any] = {}
+        hash_zch_identities = self._get_hash_zch_identities(features)
+        if hash_zch_identities is not None:
+            forward_args["hash_zch_identities"] = hash_zch_identities
+
         weights = features.weights_or_none()
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
@@ -2617,17 +2681,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
                 SSDTableBatchedEmbeddingBags,
             ),
         ):
+            forward_args["batch_size_per_feature_per_rank"] = (
+                features.stride_per_key_per_rank()
+            )
+
+        if len(forward_args) == 0:
             return self.emb_module(
                 indices=features.values().long(),
                 offsets=features.offsets().long(),
                 per_sample_weights=weights,
-                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
             )
         else:
             return self.emb_module(
                 indices=features.values().long(),
                 offsets=features.offsets().long(),
                 per_sample_weights=weights,
+                **forward_args,
             )
 
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
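
Note (not part of the patch): the sketch below is a minimal standalone illustration of the alignment invariant the new assertion enforces, assuming hypothetical table names, tracker output, and feature values. Raw IDs are gathered per table in `table_names` order and concatenated, so the result carries exactly one identity per entry in `features.values()`.

import torch

# All names and values below are made up for illustration.
table_names = ["table1", "table2"]

# Hypothetical per-table raw IDs as a tracker lookup might return them;
# each table can contribute one or more chunks.
raw_ids_dict = {
    "table1": [torch.tensor([101, 102])],
    "table2": [torch.tensor([201, 202, 203])],
}

# Hypothetical flattened feature values, ordered table1 first, then table2.
feature_values = torch.tensor([7, 8, 9, 10, 11])

# Iterate table_names (not dict order) so chunks stay in feature order,
# then concatenate into a single identities tensor.
all_raw_ids = [chunk for name in table_names for chunk in raw_ids_dict.get(name, [])]
hash_zch_identities = torch.cat(all_raw_ids)

# The 1-to-1 alignment check, mirroring the assertion added in the patch.
assert hash_zch_identities.size(0) == feature_values.numel()
print(hash_zch_identities)  # tensor([101, 102, 201, 202, 203])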