Skip to content

Commit 570eb1c

Browse files
Yixin Baometa-codesync[bot]
authored andcommitted
Add an interface in mc modules class to return a mask tensor. (#3496)
Summary: Pull Request resolved: #3496 Reviewed By: zlzhao1104 Differential Revision: D85837187 fbshipit-source-id: 34451984a6f7ec262e2aad60e6386da4fcb52514
1 parent ce16675 commit 570eb1c

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

torchrec/modules/mc_embedding_modules.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ def forward(
9393
return embedding_res, None
9494
return embedding_res, features
9595

96+
def lookup_remapped_lengths_mask(
97+
self,
98+
features: KeyedJaggedTensor,
99+
) -> torch.Tensor:
100+
features = self._managed_collision_collection(features)
101+
remapped_lengths = return_remapped_lengths_as_mask(features)
102+
return remapped_lengths
103+
104+
105+
@torch.fx.wrap
106+
def return_remapped_lengths_as_mask(features: KeyedJaggedTensor) -> torch.Tensor:
107+
return features.lengths().to(torch.bool)
108+
96109

97110
class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollection):
98111
"""

torchrec/modules/tests/test_mc_embedding_modules.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from torchrec.modules.mc_embedding_modules import (
2222
ManagedCollisionEmbeddingBagCollection,
2323
ManagedCollisionEmbeddingCollection,
24+
return_remapped_lengths_as_mask,
2425
)
2526
from torchrec.modules.mc_modules import (
2627
DistanceLFU_EvictionPolicy,
@@ -409,3 +410,20 @@ def test_mc_collection_traceable(self) -> None:
409410
)
410411
mcc.train(False)
411412
symbolic_trace(mcc, leaf_modules=[ComputeJTDictToKJT.__name__])
413+
414+
def test_return_remapped_lengths_as_mask(self) -> None:
415+
mask = return_remapped_lengths_as_mask(
416+
KeyedJaggedTensor(
417+
keys=["f0"],
418+
values=torch.rand(6),
419+
lengths=torch.tensor([1, 0, 1, 0, 1, 0, 0, 1, 1, 1], dtype=torch.int64),
420+
)
421+
)
422+
self.assertTrue(
423+
torch.equal(
424+
mask,
425+
torch.tensor(
426+
[True, False, True, False, True, False, False, True, True, True]
427+
),
428+
)
429+
)

0 commit comments

Comments
 (0)