
Commit 521f159

aliafzal authored and meta-codesync[bot] committed
Updating record_lookup function signature to accommodate future implementations (#3471)
Summary:
Pull Request resolved: #3471

General Context: We are in the process of transitioning to a unified DeltaTracker; this is diff 4 of n representing changes towards that transition.

Specific Context: Update the record_lookup function signature to accommodate the MRS DeltaTracker implementation.

Differential Revision: D80614980

fbshipit-source-id: 84f874668f7f5a5916611c93e58577d6f5dc00bc
1 parent 3ede596 commit 521f159
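
In type terms, the change moves the embedding module from a required leading parameter to an optional trailing one. The sketch below is illustrative only (the alias names OldRecordFn and NewRecordFn are not in the commit); it assumes the standard torchrec imports.

from typing import Callable, Optional

import torch
from torch import nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Before this commit: the embedding module was a required first argument.
OldRecordFn = Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]

# After this commit: (kjt, states) lead, and the module is an optional
# trailing argument, so tracker implementations that never read it can
# simply leave it as None.
NewRecordFn = Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]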

File tree

torchrec/distributed/embedding.py
torchrec/distributed/embedding_lookup.py
torchrec/distributed/embedding_types.py
torchrec/distributed/embeddingbag.py
torchrec/distributed/model_tracker/model_delta_tracker.py

5 files changed: +29 −17 lines changed


torchrec/distributed/embedding.py

Lines changed: 1 addition & 1 deletion
@@ -1587,7 +1587,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(self, features, embs)
+                self.post_lookup_tracker_fn(features, embs, self)

         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type

torchrec/distributed/embedding_lookup.py

Lines changed: 12 additions & 8 deletions
@@ -210,7 +210,7 @@ def __init__(
         self.grouped_configs = grouped_configs
         # Model tracker function to tracker optimizer state
         self.optim_state_tracker_fn: Optional[
-            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
         ] = None

     def _create_embedding_kernel(
@@ -325,7 +325,7 @@ def forward(
             # Model tracker optimizer state function, will only be set called
             # when model tracker is configured to track optimizer state
             if self.optim_state_tracker_fn is not None:
-                self.optim_state_tracker_fn(emb_op, features, lookup)
+                self.optim_state_tracker_fn(features, lookup, emb_op)

         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

@@ -432,13 +432,15 @@ def purge(self) -> None:

     def register_optim_state_tracker_fn(
         self,
-        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[
+            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+        ],
     ) -> None:
         """
         Model tracker function to tracker optimizer state

         Args:
-            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

         """
         self.optim_state_tracker_fn = record_fn
@@ -544,7 +546,7 @@ def __init__(
         )
         # Model tracker function to tracker optimizer state
         self.optim_state_tracker_fn: Optional[
-            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
         ] = None

     def _create_embedding_kernel(
@@ -710,7 +712,7 @@ def forward(
             # Model tracker optimizer state function, will only be set called
             # when model tracker is configured to track optimizer state
             if self.optim_state_tracker_fn is not None:
-                self.optim_state_tracker_fn(emb_op, features, lookup)
+                self.optim_state_tracker_fn(features, lookup, emb_op)

         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(
@@ -845,13 +847,15 @@ def purge(self) -> None:

     def register_optim_state_tracker_fn(
         self,
-        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[
+            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+        ],
     ) -> None:
         """
         Model tracker function to tracker optimizer state

         Args:
-            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

         """
         self.optim_state_tracker_fn = record_fn
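
For illustration, a callback matching the new registration type could look like the following. This is a minimal sketch; lookup_module is a hypothetical, already-constructed lookup instance (e.g. a GroupedEmbeddingsLookup), not something defined in this diff.

from typing import Optional

import torch
from torch import nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def record_optim_state(
    kjt: KeyedJaggedTensor,
    states: torch.Tensor,
    emb_module: Optional[nn.Module] = None,
) -> None:
    # Toy record function: under the new signature emb_module may be None,
    # so any use of it must be guarded. A real tracker would stash the IDs
    # and states for later delta computation.
    module_name = type(emb_module).__name__ if emb_module is not None else "n/a"
    print(f"recorded {kjt.keys()} states={tuple(states.shape)} from {module_name}")


# Hypothetical wiring; register_optim_state_tracker_fn is the API in this diff.
lookup_module.register_optim_state_tracker_fn(record_optim_state)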

torchrec/distributed/embedding_types.py

Lines changed: 5 additions & 3 deletions
@@ -391,7 +391,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

@@ -444,14 +444,16 @@ def train(self, mode: bool = True):  # pyre-ignore[3]

     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[
+            [KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None
+        ],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
         tracking the lookup results and optimizer states.

         Args:
-            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor, Optional[nn.Module]], None]): A custom record function to be called after lookup is done.

         """
         if self.post_lookup_tracker_fn is not None:
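
A plausible payoff of keeping this registration type and ModelDeltaTracker.record_lookup (next file) in the same (kjt, states, emb_module) shape is that a tracker's bound method can be passed to the hook directly. A sketch, under the assumption that tracker is a ModelDeltaTracker instance and sharded_module is a sharded embedding module exposing the hook above:

# Hypothetical usage: the bound method already matches the callback type,
# so no adapter lambda is needed.
sharded_module.register_post_lookup_tracker_fn(tracker.record_lookup)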

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(self, features, embs)
+                self.post_lookup_tracker_fn(features, embs, self)

         with maybe_annotate_embedding_event(
             EmbeddingEvent.OUTPUT_DIST,

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 10 additions & 4 deletions
@@ -79,7 +79,10 @@ class ModelDeltaTracker(ABC):

     @abstractmethod
     def record_lookup(
-        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+        self,
+        kjt: KeyedJaggedTensor,
+        states: torch.Tensor,
+        emb_module: Optional[nn.Module] = None,
     ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
@@ -233,7 +236,10 @@ def trigger_compaction(self) -> None:
         self.curr_compact_index = end_idx

     def record_lookup(
-        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+        self,
+        kjt: KeyedJaggedTensor,
+        states: torch.Tensor,
+        emb_module: Optional[nn.Module] = None,
     ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
@@ -258,12 +264,12 @@ def record_lookup(
             self.record_embeddings(kjt, states)
         # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
         elif self._mode == TrackingMode.MOMENTUM_LAST:
-            self.record_momentum(emb_module, kjt)
+            self.record_momentum(none_throws(emb_module), kjt)
         elif (
             self._mode == TrackingMode.MOMENTUM_DIFF
             or self._mode == TrackingMode.ROWWISE_ADAGRAD
         ):
-            self.record_rowwise_optim_state(emb_module, kjt)
+            self.record_rowwise_optim_state(none_throws(emb_module), kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")
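For readers unfamiliar with none_throws: it unwraps an Optional and raises if the value is None, so the module-dependent tracking modes fail fast when a caller omits emb_module. Below is a minimal re-implementation sketch with the same contract as pyre_extensions.none_throws (the version torchrec actually imports may differ in details).

from typing import Optional, TypeVar

T = TypeVar("T")


def none_throws(value: Optional[T], message: str = "Unexpected None") -> T:
    # Unwrap an Optional: return the value if present, raise otherwise. This
    # mirrors the behavior relied on by record_momentum and
    # record_rowwise_optim_state in the hunk above.
    if value is None:
        raise AssertionError(message)
    return value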