|
18 | 18 | ) |
19 | 19 |
|
20 | 20 | from torch import nn |
| 21 | +from torch.nn.parallel import DistributedDataParallel |
21 | 22 | from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding |
22 | 23 |
|
23 | 24 | from torchrec.distributed.embedding import ShardedEmbeddingCollection |
@@ -169,12 +170,15 @@ def __init__( |
169 | 170 | self._fqn_to_feature_map: Dict[str, List[str]] = {} |
170 | 171 | self._fqns_to_skip: Iterable[str] = fqns_to_skip |
171 | 172 |
|
| 173 | + logger.info(f"Model tracker enabled for {type(model.module)}") |
| 174 | + |
172 | 175 | # per_consumer_batch_idx is used to track the batch index for each consumer. |
173 | 176 | # This is used to retrieve the delta values for a given consumer as well as |
174 | 177 | # start_ids for compaction window. |
175 | 178 | self.per_consumer_batch_idx: Dict[str, int] = { |
176 | 179 | c: -1 for c in (consumers or [self.DEFAULT_CONSUMER]) |
177 | 180 | } |
| 181 | + logger.info(f"Model tracker Consumers: {self.per_consumer_batch_idx}") |
178 | 182 | self.curr_batch_idx: int = 0 |
179 | 183 | self.curr_compact_index: int = 0 |
180 | 184 |
|
@@ -401,6 +405,8 @@ def get_latest(self) -> Dict[str, torch.Tensor]: |
401 | 405 | for module in self.tracked_modules.values(): |
402 | 406 | # pyre-fixme[29]: |
403 | 407 | for lookup in module._lookups: |
| 408 | + if isinstance(lookup, DistributedDataParallel): |
| 409 | + continue |
404 | 410 | for embs_module in lookup._emb_modules: |
405 | 411 | assert isinstance( |
406 | 412 | embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding) |
@@ -616,18 +622,22 @@ def _validate_and_init_tracker_fns(self) -> None: |
616 | 622 | ): |
617 | 623 | # pyre-ignore[29]: |
618 | 624 | for lookup in module._lookups: |
619 | | - assert isinstance( |
| 625 | + if isinstance( |
620 | 626 | lookup, |
621 | 627 | (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup), |
622 | | - ) and all( |
623 | | - # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD |
624 | | - # pyre-ignore[16]: |
625 | | - emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD |
626 | | - # pyre-ignore[16]: |
627 | | - or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM |
628 | | - for emb in lookup._emb_modules |
629 | | - ) |
630 | | - lookup.register_optim_state_tracker_fn(self.record_lookup) |
| 628 | + ): |
| 629 | + for emb in lookup._emb_modules: |
 | 630 | +                        assert isinstance( |
 | 631 | +                            emb, |
 | 632 | +                            (BatchedFusedEmbedding, BatchedFusedEmbeddingBag), |
 | 633 | +                        ) |
 | 634 | +                        # TorchRec maps ROWWISE_ADAGRAD to |
 | 635 | +                        # EXACT_ROWWISE_ADAGRAD |
 | 636 | +                        assert emb._emb_module.optimizer in ( |
 | 637 | +                            OptimType.EXACT_ROWWISE_ADAGRAD, |
 | 638 | +                            OptimType.PARTIAL_ROWWISE_ADAM, |
 | 639 | +                        ) |
| 640 | + lookup.register_optim_state_tracker_fn(self.record_lookup) |
631 | 641 | else: |
632 | 642 | raise NotImplementedError( |
633 | 643 | f"Tracking mode {self._mode} is not supported" |
|
0 commit comments