
Commit da8924a

faran928 authored and meta-codesync[bot] committed
usharding <> lowering compatibility for CFR / IFR SIM (meta-pytorch#3512)
Summary:
Pull Request resolved: meta-pytorch#3512

A lowering error was thrown for torch.split(embeddings_i, length_per_key, dim=0), with lowering complaining that length_per_key is a list. The split is now wrapped in a torch.fx.wrap-ed helper that types length_per_key as List[int]. Even though torch.split is now called twice, once for embeddings and once for values, there does not seem to be any regression from this. Also, almost none of the models use the values field (it is probably not compatible with inference / usharding currently).

Reviewed By: Jason-KChen

Differential Revision: D86117088

fbshipit-source-id: 9b23fb4e15cb668d8a4080d7692badd8d1336753
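For context, here is a minimal sketch of why moving a list-valued torch.split behind a torch.fx.wrap-ed helper avoids this class of error. This is not the TorchRec code; the module, helper name, and shapes below are illustrative only. A function registered with torch.fx.wrap is recorded as a single leaf call during symbolic tracing instead of being traced through, so the Python list argument never appears inline in the traced graph that is later lowered.

from typing import List

import torch
import torch.fx


@torch.fx.wrap
def _split_by_lengths(embeddings: torch.Tensor, lengths: List[int]) -> List[torch.Tensor]:
    # Because of @torch.fx.wrap, symbolic tracing keeps this call as one
    # call_function node rather than tracing into torch.split with a Python list.
    return list(torch.split(embeddings, lengths, dim=0))


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for the code path that splits pooled embeddings per key.
    def forward(self, embeddings: torch.Tensor) -> List[torch.Tensor]:
        # Per-key lengths stay a plain Python list, as in the patched code path.
        return _split_by_lengths(embeddings, [2, 3, 1])


gm = torch.fx.symbolic_trace(TinyModel())
print(gm.graph)  # _split_by_lengths shows up as a single leaf node
print([t.shape for t in gm(torch.randn(6, 4))])

Keeping the split behind a leaf node is what the new _fx_split_embeddings_per_feature_length helper in the diff below does, for both the embeddings tensor and, when need_indices is set, the values tensor.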
1 parent bb7299e commit da8924a

File tree

1 file changed (+13, -4 lines)


torchrec/distributed/quant_embedding.py

Lines changed: 13 additions & 4 deletions
@@ -288,6 +288,15 @@ def _get_unbucketize_tensor_via_length_alignment(
     return bucketize_permute_tensor
 
 
+@torch.fx.wrap
+def _fx_split_embeddings_per_feature_length(
+    embeddings: torch.Tensor,
+    features: KeyedJaggedTensor,
+) -> List[torch.Tensor]:
+    length_per_key: List[int] = features.length_per_key()
+    return embeddings.split(length_per_key, dim=0)
+
+
 def _construct_jagged_tensors_tw(
     embeddings: List[torch.Tensor],
     embedding_names_per_rank: List[List[str]],
@@ -307,13 +316,13 @@ def _construct_jagged_tensors_tw(
 
         lengths = features_i.lengths().view(-1, features_i.stride())
         values = features_i.values()
-        length_per_key = features_i.length_per_key()
-
-        embeddings_list = torch.split(embeddings_i, length_per_key, dim=0)
+        embeddings_list = _fx_split_embeddings_per_feature_length(
+            embeddings_i, features_i
+        )
         stride = features_i.stride()
         lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
         if need_indices:
-            values_list = torch.split(values, length_per_key)
+            values_list = _fx_split_embeddings_per_feature_length(values, features_i)
         for j, key in enumerate(embedding_names_per_rank[i]):
             ret[key] = JaggedTensor(
                 lengths=lengths_tuple[j],