Skip to content

Commit 3ede596

Browse files
aliafzalmeta-codesync[bot]
authored andcommitted
Update ModelDeltaTracker to be Generic (#3470)
Summary: Pull Request resolved: #3470 Make ModelDeltaTracker generic to allow use case specific custom implementations internal General Context: We are in the process of transition to a unified DeltaTracker and this is 3/n diffs representing changes towards the transition. Specific Context: DeltaTracker implements primitives to allow tracking of embedding ids and states to optimize checkpointing and embedding freshness. As part of transitioning to a common DeltaTracker, we are adding a generic ModelDeltaTracker. MRS DeltaTracker will extend from Generic ModelDeltaTracker. Differential Revision: D80614689 fbshipit-source-id: 07126b50f4c2819f7492ee0fdb549533f102e970
1 parent 8eae1e9 commit 3ede596

File tree

3 files changed

+74
-12
lines changed

3 files changed

+74
-12
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from torch.nn.modules.module import _IncompatibleKeys
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torchrec.distributed.comm import get_local_size
32-
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTracker
32+
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
3333
from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig
3434

3535
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
@@ -293,7 +293,7 @@ def __init__(
293293
if init_data_parallel:
294294
self.init_data_parallel()
295295

296-
self.model_delta_tracker: Optional[ModelDeltaTracker] = (
296+
self.model_delta_tracker: Optional[ModelDeltaTrackerTrec] = (
297297
self._init_delta_tracker(model_tracker_config, self._dmp_wrapped_module)
298298
if model_tracker_config is not None
299299
else None
@@ -369,9 +369,9 @@ def _init_dmp(self, module: nn.Module) -> nn.Module:
369369

370370
def _init_delta_tracker(
371371
self, model_tracker_config: ModelTrackerConfig, module: nn.Module
372-
) -> ModelDeltaTracker:
372+
) -> ModelDeltaTrackerTrec:
373373
# Init delta tracker if config is provided
374-
return ModelDeltaTracker(
374+
return ModelDeltaTrackerTrec(
375375
model=module,
376376
consumers=model_tracker_config.consumers,
377377
delete_on_read=model_tracker_config.delete_on_read,
@@ -456,7 +456,7 @@ def init_parameters(module: nn.Module) -> None:
456456

457457
module.apply(init_parameters)
458458

459-
def get_model_tracker(self) -> ModelDeltaTracker:
459+
def get_model_tracker(self) -> ModelDeltaTrackerTrec:
460460
"""
461461
Returns the model tracker if it exists.
462462
"""

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-strict
99
import logging as logger
10+
from abc import ABC, abstractmethod
1011
from collections import Counter, OrderedDict
1112
from typing import Dict, Iterable, List, Optional, Tuple
1213

@@ -64,10 +65,73 @@
6465
SUPPORTED_MODULES = (ShardedEmbeddingCollection, ShardedEmbeddingBagCollection)
6566

6667

67-
class ModelDeltaTracker:
68+
class ModelDeltaTracker(ABC):
6869
r"""
6970
70-
ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
71+
Abstract base class for ModelDeltaTracker that provides a way to track and retrieve unique IDs for supported modules,
72+
along with optional support for tracking corresponding embeddings or states. This is useful for identifying and
73+
retrieving the latest delta or unique rows for a given model, which can help compute topk or to stream updated
74+
embeddings from predictors to trainers during online training.
75+
76+
"""
77+
78+
DEFAULT_CONSUMER: str = "default"
79+
80+
@abstractmethod
81+
def record_lookup(
82+
self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
83+
) -> None:
84+
"""
85+
Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.
86+
87+
Args:
88+
emb_module (nn.Module): The embedding module in which the lookup was performed.
89+
kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
90+
states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
91+
"""
92+
pass
93+
94+
@abstractmethod
95+
def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
96+
"""
97+
Return a dictionary of hit local IDs for each sparse feature.
98+
99+
Args:
100+
consumer (str, optional): The consumer to retrieve unique IDs for.
101+
"""
102+
pass
103+
104+
@abstractmethod
105+
def get_unique(
106+
self,
107+
consumer: Optional[str] = None,
108+
top_percentage: Optional[float] = 1.0,
109+
per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
110+
sorted_by_indices: Optional[bool] = True,
111+
) -> Dict[str, DeltaRows]:
112+
"""
113+
Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature.
114+
115+
Args:
116+
consumer (str, optional): The consumer to retrieve delta values for.
117+
"""
118+
pass
119+
120+
@abstractmethod
121+
def clear(self, consumer: Optional[str] = None) -> None:
122+
"""
123+
Clear tracked IDs for a given consumer.
124+
125+
Args:
126+
consumer (str, optional): The consumer to clear IDs/States for.
127+
"""
128+
pass
129+
130+
131+
class ModelDeltaTrackerTrec(ModelDeltaTracker):
132+
r"""
133+
134+
ModelDeltaTrackerTrec provides a way to track and retrieve unique IDs for supported modules, along with optional support
71135
for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
72136
unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during
73137
online training. Unique IDs or states can be retrieved by calling the get_delta() method.
@@ -85,8 +149,6 @@ class ModelDeltaTracker:
85149
86150
"""
87151

88-
DEFAULT_CONSUMER: str = "default"
89-
90152
def __init__(
91153
self,
92154
model: nn.Module,
@@ -354,7 +416,7 @@ def get_latest(self) -> Dict[str, torch.Tensor]:
354416

355417
return ret
356418

357-
def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
419+
def get_unique_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
358420
"""
359421
Return a dictionary of hit local IDs for each sparse feature. Ids are
360422
first keyed by submodule FQN.

torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,7 +1739,7 @@ def _test_id_mode(
17391739
tracked_out.sum().backward()
17401740
baseline_out.sum().backward()
17411741

1742-
delta_ids = dt.get_delta_ids()
1742+
delta_ids = dt.get_unique_ids()
17431743

17441744
table_fqns = dt.fqn_to_feature_names().keys()
17451745

@@ -2035,7 +2035,7 @@ def _test_multiple_consumer(
20352035
unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
20362036
tracked_out.sum().backward()
20372037
baseline_out.sum().backward()
2038-
delta_rows = dt.get_delta_ids(consumer=consumer)
2038+
delta_rows = dt.get_unique_ids(consumer=consumer)
20392039

20402040
# Verify that the current batch index is correct
20412041
unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1)

0 commit comments

Comments
 (0)