Skip to content

Commit d6ee3e0

Browse files
Joey Yang and meta-codesync[bot]
authored and committed
Look up and forward raw_ids from tracker to TBE from embedding module (#3527)
Summary: Pull Request resolved: #3527 We look up the identities corresponding to slot index from `raw_id_tracker` and forward it to TBE. The identities are stored in `raw_id_tracker` during `mc_module` lookup. `raw_ids` are retrieved per-table from `raw_id_tracker` with the API `get_indexed_lookup()`; we concatenate them to maintain 1-to-1 alignment with `features.values()`. Reviewed By: emlin, chouxi Differential Revision: D86242001 fbshipit-source-id: b3b3e1ee4a6b966025e83a2aa0a534d4b2ef637a
1 parent a34ff1e commit d6ee3e0

File tree

1 file changed

+71
-2
lines changed

1 file changed

+71
-2
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2605,7 +2605,71 @@ def init_parameters(self) -> None:
26052605
weight_init_max,
26062606
)
26072607

2608-
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
2608+
def _get_hash_zch_identities(
    self, features: KeyedJaggedTensor
) -> Optional[torch.Tensor]:
    """Collect the raw-id identities tracked for this embedding module.

    The identities were recorded in the raw-id tracker during the
    mc_module lookup; here they are fetched per table and concatenated
    so the result stays 1-to-1 aligned with ``features.values()``.

    Returns:
        A tensor of identities aligned with ``features.values()``, or
        ``None`` when no tracker is attached, the kernel is not a
        ``SplitTableBatchedEmbeddingBagsCodegen``, or nothing was tracked.
    """
    tracker = self._raw_id_tracker_wrapper
    # Identities are only meaningful with a tracker attached and the
    # split-table codegen TBE kernel underneath.
    if tracker is None or not isinstance(
        self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
    ):
        return None

    assert (
        tracker is not None
    ), "self._raw_id_tracker_wrapper should not be None"
    assert hasattr(
        self.emb_module, "res_params"
    ), "res_params should exist when raw_id_tracker is enabled"
    res_params: RESParams = self.emb_module.res_params  # pyre-ignore[9]
    table_names = res_params.table_names

    # TODO: get_indexed_lookups() may return multiple IndexedLookup
    # objects accumulated across training iterations. Appending raw_ids
    # from every batch sequentially could misalign with
    # features.values(), which holds only the current batch.
    raw_ids_dict = tracker.get_indexed_lookups(
        table_names, self.emb_module.uuid
    )

    # Concatenate raw ids in table_names order (never raw_ids_dict
    # iteration order) so the output matches the feature layout, e.g.
    # features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
    # yields hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...].
    #
    # TODO: tables without identity tracking are simply skipped here; if
    # only some tables track identities, -1 padding may be required to
    # keep alignment with features.values().
    gathered = [
        raw_ids
        for table_name in table_names
        for raw_ids in raw_ids_dict.get(table_name, [])
    ]

    if not gathered:
        return None

    hash_zch_identities = torch.cat(gathered)
    # Guard the 1-to-1 alignment contract with the current batch.
    assert hash_zch_identities.size(0) == features.values().numel(), (
        f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
        f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
    )

    return hash_zch_identities
2663+
2664+
def forward(
2665+
self,
2666+
features: KeyedJaggedTensor,
2667+
) -> torch.Tensor:
2668+
forward_args: Dict[str, Any] = {}
2669+
hash_zch_identities = self._get_hash_zch_identities(features)
2670+
if hash_zch_identities is not None:
2671+
forward_args["hash_zch_identities"] = hash_zch_identities
2672+
26092673
weights = features.weights_or_none()
26102674
if weights is not None and not torch.is_floating_point(weights):
26112675
weights = None
@@ -2617,17 +2681,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
26172681
SSDTableBatchedEmbeddingBags,
26182682
),
26192683
):
2684+
forward_args["batch_size_per_feature_per_rank"] = (
2685+
features.stride_per_key_per_rank()
2686+
)
2687+
2688+
if len(forward_args) == 0:
26202689
return self.emb_module(
26212690
indices=features.values().long(),
26222691
offsets=features.offsets().long(),
26232692
per_sample_weights=weights,
2624-
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
26252693
)
26262694
else:
26272695
return self.emb_module(
26282696
indices=features.values().long(),
26292697
offsets=features.offsets().long(),
26302698
per_sample_weights=weights,
2699+
**forward_args,
26312700
)
26322701

26332702
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.

0 commit comments

Comments
 (0)