
Commit 3350679

aliafzal authored and meta-codesync[bot] committed
ModelDeltaTracker optim state init bug fix (#3476)
Summary:
Pull Request resolved: #3476

Ensure tracker functions are initialized correctly for BatchedFusedEmbedding and BatchedFusedEmbeddingBag.

Differential Revision: D85119191

fbshipit-source-id: 4ec77c8aa8bdc2f1d98e8ce4794d0de8bab97c05
1 parent 0e20667 commit 3350679

File tree: 1 file changed (+20, -10)


torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 20 additions & 10 deletions
@@ -18,6 +18,7 @@
 )
 
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding
 
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
@@ -169,12 +170,15 @@ def __init__(
         self._fqn_to_feature_map: Dict[str, List[str]] = {}
         self._fqns_to_skip: Iterable[str] = fqns_to_skip
 
+        logger.info(f"Model tracker enabled for {type(model.module)}")
+
         # per_consumer_batch_idx is used to track the batch index for each consumer.
         # This is used to retrieve the delta values for a given consumer as well as
         # start_ids for compaction window.
         self.per_consumer_batch_idx: Dict[str, int] = {
             c: -1 for c in (consumers or [self.DEFAULT_CONSUMER])
         }
+        logger.info(f"Model tracker Consumers: {self.per_consumer_batch_idx}")
         self.curr_batch_idx: int = 0
         self.curr_compact_index: int = 0
 
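Note that the first added log line dereferences model.module, which presumes the tracked model is a wrapper (such as DistributedModelParallel or DistributedDataParallel) that exposes the inner model via a .module attribute; a bare nn.Module would raise AttributeError there. A tiny standalone illustration of that assumption (the Wrapper class is a stand-in, not TorchRec code):

    from torch import nn

    class Wrapper(nn.Module):
        """Stand-in for a DMP/DDP-style wrapper that exposes `.module`."""

        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            self.module = module

    model = Wrapper(nn.Linear(4, 4))
    # Mirrors the added log line; prints the inner module's type.
    print(f"Model tracker enabled for {type(model.module)}")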
@@ -401,6 +405,8 @@ def get_latest(self) -> Dict[str, torch.Tensor]:
         for module in self.tracked_modules.values():
             # pyre-fixme[29]:
             for lookup in module._lookups:
+                if isinstance(lookup, DistributedDataParallel):
+                    continue
                 for embs_module in lookup._emb_modules:
                     assert isinstance(
                         embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
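The new guard matters when a lookup has been wrapped in DistributedDataParallel (as happens, for example, with data-parallel sharding): the DDP wrapper does not forward arbitrary attribute access to the inner module, so reading lookup._emb_modules on it would fail. A minimal sketch of the pattern in isolation (iter_tracked_emb_modules is an illustrative helper, not part of the commit):

    from typing import Iterable, Iterator

    from torch import nn
    from torch.nn.parallel import DistributedDataParallel

    def iter_tracked_emb_modules(
        lookups: Iterable[nn.Module],
    ) -> Iterator[nn.Module]:
        """Yield embedding kernels, skipping DDP-wrapped lookups."""
        for lookup in lookups:
            if isinstance(lookup, DistributedDataParallel):
                # The wrapper keeps the real lookup under `lookup.module`
                # and does not expose `_emb_modules` itself, so skip it.
                continue
            yield from lookup._emb_modules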
@@ -616,18 +622,22 @@ def _validate_and_init_tracker_fns(self) -> None:
         ):
             # pyre-ignore[29]:
             for lookup in module._lookups:
-                assert isinstance(
+                if isinstance(
                     lookup,
                     (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
-                ) and all(
-                    # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
-                    # pyre-ignore[16]:
-                    emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
-                    # pyre-ignore[16]:
-                    or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
-                    for emb in lookup._emb_modules
-                )
-                lookup.register_optim_state_tracker_fn(self.record_lookup)
+                ):
+                    for emb in lookup._emb_modules:
+                        assert (
+                            isinstance(
+                                emb,
+                                (BatchedFusedEmbedding, BatchedFusedEmbeddingBag),
+                            )
+                            and emb._emb_module.optimizer
+                            # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                            == OptimType.EXACT_ROWWISE_ADAGRAD
+                            or OptimType.PARTIAL_ROWWISE_ADAM
+                        )
+                        lookup.register_optim_state_tracker_fn(self.record_lookup)
         else:
             raise NotImplementedError(
                 f"Tracking mode {self._mode} is not supported"
