Skip to content

Commit c2414ac

Browse files
aliafzalmeta-codesync[bot]
authored andcommitted
Renamed types to account for broader use cases (#3473)
Summary: Pull Request resolved: #3473 internal General Context: We are in the process of transition to a unified DeltaTracker and this is 5/n diffs representing changes towards the transition. Specific Context: Consolidating all types to a common place within torchRec Differential Revision: D80615183 fbshipit-source-id: c730f312a1314c138f54d21b80b4ff12830f8e8f
1 parent 521f159 commit c2414ac

File tree

6 files changed

+95
-95
lines changed

6 files changed

+95
-95
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from torch.nn.parallel import DistributedDataParallel
3131
from torchrec.distributed.comm import get_local_size
3232
from torchrec.distributed.model_tracker.model_delta_tracker import ModelDeltaTrackerTrec
33-
from torchrec.distributed.model_tracker.types import DeltaRows, ModelTrackerConfig
33+
from torchrec.distributed.model_tracker.types import ModelTrackerConfig, UniqueRows
3434

3535
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
3636
from torchrec.distributed.sharding_plan import get_default_sharders
@@ -466,7 +466,7 @@ def get_model_tracker(self) -> ModelDeltaTrackerTrec:
466466
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
467467
return self.model_delta_tracker
468468

469-
def get_unique(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
469+
def get_unique(self, consumer: Optional[str] = None) -> Dict[str, UniqueRows]:
470470
"""
471471
Returns the delta rows for the given consumer.
472472
"""

torchrec/distributed/model_tracker/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
SUPPORTED_MODULES, # noqa
2828
)
2929
from torchrec.distributed.model_tracker.types import (
30-
DeltaRows, # noqa
31-
EmbdUpdateMode, # noqa
3230
IndexedLookup, # noqa
3331
ModelTrackerConfig, # noqa
3432
TrackingMode, # noqa
33+
UniqueRows, # noqa
34+
UpdateMode, # noqa
3535
)

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,34 @@
1313

1414
import torch
1515
from torchrec.distributed.model_tracker.types import (
16-
DeltaRows,
17-
EmbdUpdateMode,
1816
IndexedLookup,
17+
UniqueRows,
18+
UpdateMode,
1919
)
2020
from torchrec.distributed.utils import none_throws
2121

2222

2323
def _compute_unique_rows(
2424
ids: List[torch.Tensor],
2525
states: Optional[List[torch.Tensor]],
26-
mode: EmbdUpdateMode,
27-
) -> DeltaRows:
26+
mode: UpdateMode,
27+
) -> UniqueRows:
2828
r"""
2929
To calculate unique ids and embeddings
3030
"""
31-
if mode == EmbdUpdateMode.NONE:
32-
assert states is None, f"{mode=} == EmbdUpdateMode.NONE but received embeddings"
31+
if mode == UpdateMode.NONE:
32+
assert states is None, f"{mode=} == UpdateMode.NONE but received embeddings"
3333
unique_ids = torch.cat(ids).unique(return_inverse=False)
34-
return DeltaRows(ids=unique_ids, states=None)
34+
return UniqueRows(ids=unique_ids, states=None)
3535
else:
3636
assert (
3737
states is not None
38-
), f"{mode=} != EmbdUpdateMode.NONE but received no embeddings"
38+
), f"{mode=} != UpdateMode.NONE but received no embeddings"
3939

4040
cat_ids = torch.cat(ids)
4141
cat_states = torch.cat(states)
4242

43-
if mode == EmbdUpdateMode.LAST:
43+
if mode == UpdateMode.LAST:
4444
cat_ids = cat_ids.flip(dims=[0])
4545
cat_states = cat_states.flip(dims=[0])
4646

@@ -65,7 +65,7 @@ def _compute_unique_rows(
6565

6666
# Use first occurrence indices to select corresponding embedding row.
6767
unique_states = cat_states[first_occurrence]
68-
return DeltaRows(ids=unique_ids, states=unique_states)
68+
return UniqueRows(ids=unique_ids, states=unique_states)
6969

7070

7171
class DeltaStore(ABC):
@@ -81,7 +81,7 @@ class DeltaStore(ABC):
8181
"""
8282

8383
@abstractmethod
84-
def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
84+
def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None:
8585
pass
8686

8787
@abstractmethod
@@ -125,7 +125,7 @@ def compact(self, start_idx: int, end_idx: int) -> None:
125125
pass
126126

127127
@abstractmethod
128-
def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
128+
def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
129129
"""
130130
Return all unique/delta ids per table from the Delta Store.
131131
@@ -151,9 +151,9 @@ class DeltaStoreTrec(DeltaStore):
151151
how to handle duplicate ids when compacting or retrieving embeddings.
152152
"""
153153

154-
def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
155-
super().__init__(embdUpdateMode)
156-
self.embdUpdateMode = embdUpdateMode
154+
def __init__(self, updateMode: UpdateMode = UpdateMode.NONE) -> None:
155+
super().__init__(updateMode)
156+
self.updateMode = updateMode
157157
self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
158158

159159
def append(
@@ -205,11 +205,11 @@ def compact(self, start_idx: int, end_idx: int) -> None:
205205
ids = [lookup.ids for lookup in lookups_to_compact]
206206
states = (
207207
[none_throws(lookup.states) for lookup in lookups_to_compact]
208-
if self.embdUpdateMode != EmbdUpdateMode.NONE
208+
if self.updateMode != UpdateMode.NONE
209209
else None
210210
)
211211
delta_rows = _compute_unique_rows(
212-
ids=ids, states=states, mode=self.embdUpdateMode
212+
ids=ids, states=states, mode=self.updateMode
213213
)
214214
new_per_fqn_lookups[table_fqn] = (
215215
lookups[:index_l]
@@ -224,12 +224,12 @@ def compact(self, start_idx: int, end_idx: int) -> None:
224224
)
225225
self.per_fqn_lookups = new_per_fqn_lookups
226226

227-
def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
227+
def get_unique(self, from_idx: int = 0) -> Dict[str, UniqueRows]:
228228
r"""
229229
Return all unique/delta ids per table from the Delta Store.
230230
"""
231231

232-
delta_per_table_fqn: Dict[str, DeltaRows] = {}
232+
delta_per_table_fqn: Dict[str, UniqueRows] = {}
233233
for table_fqn, lookups in self.per_fqn_lookups.items():
234234
compact_ids = [
235235
lookup.ids for lookup in lookups if lookup.batch_idx >= from_idx
@@ -240,11 +240,11 @@ def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
240240
for lookup in lookups
241241
if lookup.batch_idx >= from_idx
242242
]
243-
if self.embdUpdateMode != EmbdUpdateMode.NONE
243+
if self.updateMode != UpdateMode.NONE
244244
else None
245245
)
246246

247247
delta_per_table_fqn[table_fqn] = _compute_unique_rows(
248-
ids=compact_ids, states=compact_states, mode=self.embdUpdateMode
248+
ids=compact_ids, states=compact_states, mode=self.updateMode
249249
)
250250
return delta_per_table_fqn

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,36 +29,36 @@
2929
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
3030
from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
3131
from torchrec.distributed.model_tracker.types import (
32-
DeltaRows,
33-
EmbdUpdateMode,
3432
TrackingMode,
33+
UniqueRows,
34+
UpdateMode,
3535
)
3636
from torchrec.distributed.utils import none_throws
3737

3838
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
3939

40-
UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
40+
UPDATE_MODE_MAP: Dict[TrackingMode, UpdateMode] = {
4141
# Only IDs are tracked, no additional state is stored.
42-
TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
43-
# TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
42+
TrackingMode.ID_ONLY: UpdateMode.NONE,
43+
# TrackingMode.EMBEDDING utilizes UpdateMode.FIRST to ensure that
4444
# the earliest embedding values are stored since the last checkpoint
4545
# or snapshot. This mode is used for computing topk delta rows, which
4646
# is currently achieved by running (new_emb - old_emb).norm().topk().
47-
TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
48-
# TrackingMode.MOMENTUM utilizes EmbdUpdateMode.LAST to ensure that
47+
TrackingMode.EMBEDDING: UpdateMode.FIRST,
48+
# TrackingMode.MOMENTUM utilizes UpdateMode.LAST to ensure that
4949
# the most recent momentum values—capturing the accumulated gradient
5050
# direction and magnitude—are stored since the last batch.
5151
# This mode supports approximate top-k delta-row selection, can be
5252
# obtained by running momentum.norm().topk().
53-
TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
53+
TrackingMode.MOMENTUM_LAST: UpdateMode.LAST,
5454
# MOMENTUM_DIFF keeps a running sum of the square of the gradients per row.
5555
# Within each publishing interval, we track the starting value of this running
5656
# sum on all used rows and then do a lookup when ``get_delta`` is called to query
5757
# the latest sum. Then we can compute the delta of the two values and return them
5858
# together with the row ids.
59-
TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
59+
TrackingMode.MOMENTUM_DIFF: UpdateMode.FIRST,
6060
# The same as MOMENTUM_DIFF. Adding for backward compatibility.
61-
TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
61+
TrackingMode.ROWWISE_ADAGRAD: UpdateMode.FIRST,
6262
}
6363

6464
# Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -111,7 +111,7 @@ def get_unique(
111111
top_percentage: Optional[float] = 1.0,
112112
per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
113113
sorted_by_indices: Optional[bool] = True,
114-
) -> Dict[str, DeltaRows]:
114+
) -> Dict[str, UniqueRows]:
115115
"""
116116
Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature.
117117
@@ -439,7 +439,7 @@ def get_unique(
439439
top_percentage: Optional[float] = 1.0,
440440
per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
441441
sorted_by_indices: Optional[bool] = True,
442-
) -> Dict[str, DeltaRows]:
442+
) -> Dict[str, UniqueRows]:
443443
"""
444444
Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The Values are first keyed by submodule FQN.
445445
@@ -471,8 +471,7 @@ def get_unique(
471471
assert (
472472
fqn in square_sum_map
473473
), f"{fqn} not found in {square_sum_map.keys()}"
474-
# pyre-fixme[58]: `-` is not supported for operand types `Tensor`
475-
# and `Optional[Tensor]`.
474+
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and `Optional[Tensor]`.
476475
rows.states = square_sum_map[fqn][rows.ids] - rows.states
477476

478477
return tracker_rows

0 commit comments

Comments
 (0)