73 changes: 71 additions & 2 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -2605,7 +2605,71 @@ def init_parameters(self) -> None:
weight_init_max,
)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
def _get_hash_zch_identities(
self, features: KeyedJaggedTensor
) -> Optional[torch.Tensor]:
if self._raw_id_tracker_wrapper is None or not isinstance(
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
):
return None

raw_id_tracker_wrapper = self._raw_id_tracker_wrapper
assert (
raw_id_tracker_wrapper is not None
), "self._raw_id_tracker_wrapper should not be None"
assert hasattr(
self.emb_module, "res_params"
), "res_params should exist when raw_id_tracker is enabled"
res_params: RESParams = self.emb_module.res_params # pyre-ignore[9]
table_names = res_params.table_names

# TODO: get_indexed_lookups() may return multiple IndexedLookup objects
# across multiple training iterations. Current logic appends raw_ids from
# all batches sequentially. This may cause misalignment with
# features.values() which only contains the current batch.
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
table_names, self.emb_module.uuid
)

# Build hash_zch_identities by concatenating raw IDs from tracked tables.
# Output maintains 1-to-1 alignment with features.values().
# Iterate through table_names explicitly (not raw_ids_dict.values()) to
# ensure correct ordering, since there is no guarantee on dict ordering.
#
# E.g. If features.values() = [f1_val1, f1_val2, f2_val1, f2_val2, ...]
# where table1 has [feature1, feature2] and table2 has [feature3, feature4]
# then hash_zch_identities = [f1_id1, f1_id2, f2_id1, f2_id2, ...]
#
# TODO: Handle tables without identity tracking. Currently, only tables with
# raw_ids are included. If some tables lack identity while others have them,
# padding with -1 may be needed to maintain alignment.
all_raw_ids = []
for table_name in table_names:
if table_name in raw_ids_dict:
raw_ids_list = raw_ids_dict[table_name]
for raw_ids in raw_ids_list:
all_raw_ids.append(raw_ids)

if not all_raw_ids:
return None

hash_zch_identities = torch.cat(all_raw_ids)
assert hash_zch_identities.size(0) == features.values().numel(), (
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
)

return hash_zch_identities

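    # ------------------------------------------------------------------
    # Illustrative sketch (not part of this diff): how concatenating the
    # per-table raw-ID tensors in table_names order is expected to line up
    # 1-to-1 with features.values(). All table names, ID values, and shapes
    # below are hypothetical toy data chosen only for the example.
    #
    #     import torch
    #
    #     raw_ids_dict = {
    #         "table1": [torch.tensor([101, 102, 103])],  # 3 values routed to table1
    #         "table2": [torch.tensor([201, 202])],       # 2 values routed to table2
    #     }
    #     table_names = ["table1", "table2"]  # iteration order defines alignment
    #
    #     all_raw_ids = []
    #     for table_name in table_names:
    #         for raw_ids in raw_ids_dict.get(table_name, []):
    #             all_raw_ids.append(raw_ids)
    #
    #     hash_zch_identities = torch.cat(all_raw_ids)
    #     features_values = torch.tensor([7, 8, 9, 10, 11])  # stand-in for features.values()
    #     assert hash_zch_identities.size(0) == features_values.numel()
    # ------------------------------------------------------------------
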
def forward(
self,
features: KeyedJaggedTensor,
) -> torch.Tensor:
forward_args: Dict[str, Any] = {}
hash_zch_identities = self._get_hash_zch_identities(features)
if hash_zch_identities is not None:
forward_args["hash_zch_identities"] = hash_zch_identities

weights = features.weights_or_none()
if weights is not None and not torch.is_floating_point(weights):
weights = None
@@ -2617,17 +2681,22 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
SSDTableBatchedEmbeddingBags,
),
):
forward_args["batch_size_per_feature_per_rank"] = (
features.stride_per_key_per_rank()
)

if len(forward_args) == 0:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
per_sample_weights=weights,
batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
)
else:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
per_sample_weights=weights,
**forward_args,
)

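    # ------------------------------------------------------------------
    # Illustrative sketch (not part of this diff): the dispatch pattern used
    # by forward() above. Optional kwargs are collected in a dict and splatted
    # into the embedding-module call only when at least one is present, so
    # module variants whose forward() does not accept these kwargs keep
    # working unchanged. FakeEmbModule and the tensors below are hypothetical
    # stand-ins for self.emb_module and the real inputs.
    #
    #     from typing import Any, Dict, Optional
    #     import torch
    #
    #     class FakeEmbModule(torch.nn.Module):
    #         def forward(
    #             self,
    #             indices: torch.Tensor,
    #             offsets: torch.Tensor,
    #             per_sample_weights: Optional[torch.Tensor] = None,
    #             **kwargs: Any,
    #         ) -> torch.Tensor:
    #             # A real kernel would consume hash_zch_identities or
    #             # batch_size_per_feature_per_rank here when provided.
    #             return indices.float()
    #
    #     emb_module = FakeEmbModule()
    #     forward_args: Dict[str, Any] = {}
    #     hash_zch_identities = torch.tensor([101, 102, 103])  # e.g. from _get_hash_zch_identities
    #     if hash_zch_identities is not None:
    #         forward_args["hash_zch_identities"] = hash_zch_identities
    #
    #     indices = torch.tensor([1, 2, 3])
    #     offsets = torch.tensor([0, 2, 3])
    #     if len(forward_args) == 0:
    #         out = emb_module(indices=indices.long(), offsets=offsets.long())
    #     else:
    #         out = emb_module(indices=indices.long(), offsets=offsets.long(), **forward_args)
    # ------------------------------------------------------------------
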
# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.