Skip to content

Commit 4984377

Browse files
aliafzalmeta-codesync[bot]
authored and committed
Update DeltaStore to be Generic (#3468)
Summary: Pull Request resolved: #3468. Make DeltaStore generic to allow use-case-specific custom implementations. Internal General Context: We are in the process of transitioning to a unified DeltaTracker, and this is diff 1/n of the changes towards that transition. Specific Context: DeltaTracker utilizes Memstore to preserve and compact lookups extracted during embedding lookups. As part of transitioning to a common DeltaTracker, we are adding a generic DeltaStore. Memstore will extend from the generic DeltaStore, allowing both MRS and OSS DeltaTrackers to be easily integrated into training frameworks. Differential Revision: D80614364 fbshipit-source-id: 9ef57943bfa4ea1ff630d14d2ce7805775b6505f
1 parent 5e9763d commit 4984377

File tree

3 files changed

+93
-20
lines changed

3 files changed

+93
-20
lines changed

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88

99
# pyre-strict
10+
from abc import ABC, abstractmethod
1011
from bisect import bisect_left
1112
from typing import Dict, List, Optional
1213

@@ -67,34 +68,106 @@ def _compute_unique_rows(
6768
return DeltaRows(ids=unique_ids, states=unique_states)
6869

6970

70-
class DeltaStore:
71+
class DeltaStore(ABC):
7172
"""
72-
DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across
73-
various batches during training, designed to be used with TorchRecs ModelDeltaTracker.
74-
It maintains a CUDA in-memory representation of requested ids and embeddings/states,
73+
DeltaStore is an abstract base class that defines the interface for storing and managing
74+
local delta (row) updates for embeddings/states across various batches during training.
75+
76+
Implementations should maintain a representation of requested ids and embeddings/states,
7577
providing a way to compact and get delta updates for each embedding table.
7678
7779
The class supports different embedding update modes (NONE, FIRST, LAST) to determine
7880
how to handle duplicate ids when compacting or retrieving embeddings.
81+
"""
82+
83+
@abstractmethod
84+
def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
85+
pass
86+
87+
@abstractmethod
88+
def append(
89+
self,
90+
batch_idx: int,
91+
fqn: str,
92+
ids: torch.Tensor,
93+
states: Optional[torch.Tensor],
94+
) -> None:
95+
"""
96+
Append a batch of ids and states to the store for a specific table.
97+
98+
Args:
99+
batch_idx: The batch index
100+
fqn: The fully qualified name of the table
101+
ids: The tensor of ids to append
102+
states: Optional tensor of states to append
103+
"""
104+
pass
105+
106+
@abstractmethod
107+
def delete(self, up_to_idx: Optional[int] = None) -> None:
108+
"""
109+
Delete all idx from the store up to `up_to_idx`
110+
111+
Args:
112+
up_to_idx: Optional index up to which to delete lookups
113+
"""
114+
pass
79115

116+
@abstractmethod
117+
def compact(self, start_idx: int, end_idx: int) -> None:
118+
"""
119+
Compact (ids, embeddings) in batch index range from start_idx to end_idx.
120+
121+
Args:
122+
start_idx: The starting batch index
123+
end_idx: The ending batch index
124+
"""
125+
pass
126+
127+
@abstractmethod
128+
def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
129+
"""
130+
Return all unique/delta ids per table from the Delta Store.
131+
132+
Args:
133+
from_idx: The batch index from which to get deltas
134+
135+
Returns:
136+
A dictionary mapping table FQNs to their delta rows
137+
"""
138+
pass
139+
140+
141+
class DeltaStoreTrec(DeltaStore):
142+
"""
143+
DeltaStoreTrec is a concrete implementation of DeltaStore that stores and manages
144+
local delta (row) updates for embeddings/states across various batches during training,
145+
designed to be used with TorchRec's ModelDeltaTracker.
146+
147+
It maintains a CUDA in-memory representation of requested ids and embeddings/states,
148+
providing a way to compact and get delta updates for each embedding table.
149+
150+
The class supports different embedding update modes (NONE, FIRST, LAST) to determine
151+
how to handle duplicate ids when compacting or retrieving embeddings.
80152
"""
81153

82154
def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None:
155+
super().__init__(embdUpdateMode)
83156
self.embdUpdateMode = embdUpdateMode
84157
self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {}
85158

86159
def append(
87160
self,
88161
batch_idx: int,
89-
table_fqn: str,
162+
fqn: str,
90163
ids: torch.Tensor,
91164
states: Optional[torch.Tensor],
92165
) -> None:
93-
table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, [])
166+
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
94167
table_fqn_lookup.append(
95168
IndexedLookup(batch_idx=batch_idx, ids=ids, states=states)
96169
)
97-
self.per_fqn_lookups[table_fqn] = table_fqn_lookup
170+
self.per_fqn_lookups[fqn] = table_fqn_lookup
98171

99172
def delete(self, up_to_idx: Optional[int] = None) -> None:
100173
"""

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
GroupedPooledEmbeddingsLookup,
2727
)
2828
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
29-
from torchrec.distributed.model_tracker.delta_store import DeltaStore
29+
from torchrec.distributed.model_tracker.delta_store import DeltaStoreTrec
3030
from torchrec.distributed.model_tracker.types import (
3131
DeltaRows,
3232
EmbdUpdateMode,
@@ -122,7 +122,7 @@ def __init__(
122122
# Validate if the mode is supported for the given module and initialize tracker functions
123123
self._validate_and_init_tracker_fns()
124124

125-
self.store: DeltaStore = DeltaStore(UPDATE_MODE_MAP[self._mode])
125+
self.store: DeltaStoreTrec = DeltaStoreTrec(UPDATE_MODE_MAP[self._mode])
126126

127127
# Mapping feature name to corresponding FQNs. This is used for retrieving
128128
# the FQN associated with a given feature name in record_lookup().
@@ -222,7 +222,7 @@ def record_ids(self, kjt: KeyedJaggedTensor) -> None:
222222
for table_fqn, ids_list in per_table_ids.items():
223223
self.store.append(
224224
batch_idx=self.curr_batch_idx,
225-
table_fqn=table_fqn,
225+
fqn=table_fqn,
226226
ids=torch.cat(ids_list),
227227
states=None,
228228
)
@@ -262,7 +262,7 @@ def record_embeddings(
262262
for table_fqn, ids_list in per_table_ids.items():
263263
self.store.append(
264264
batch_idx=self.curr_batch_idx,
265-
table_fqn=table_fqn,
265+
fqn=table_fqn,
266266
ids=torch.cat(ids_list),
267267
states=torch.cat(per_table_emb[table_fqn]),
268268
)
@@ -295,7 +295,7 @@ def record_momentum(
295295
per_key_states = states[offsets[i] : offsets[i + 1]]
296296
self.store.append(
297297
batch_idx=self.curr_batch_idx,
298-
table_fqn=fqn,
298+
fqn=fqn,
299299
ids=kjt[key].values(),
300300
states=per_key_states,
301301
)
@@ -323,7 +323,7 @@ def record_rowwise_optim_state(
323323
per_key_states = states[offsets[i] : offsets[i + 1]]
324324
self.store.append(
325325
batch_idx=self.curr_batch_idx,
326-
table_fqn=fqn,
326+
fqn=fqn,
327327
ids=kjt[key].values(),
328328
states=per_key_states,
329329
)

torchrec/distributed/model_tracker/tests/test_delta_store.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from parameterized import parameterized
1616
from torchrec.distributed.model_tracker.delta_store import (
1717
_compute_unique_rows,
18-
DeltaStore,
18+
DeltaStoreTrec,
1919
)
2020
from torchrec.distributed.model_tracker.types import (
2121
DeltaRows,
@@ -24,7 +24,7 @@
2424
)
2525

2626

27-
class DeltaStoreTest(unittest.TestCase):
27+
class DeltaStoreTrecTest(unittest.TestCase):
2828
# pyre-fixme[2]: Parameter must be annotated.
2929
def __init__(self, methodName="runTest") -> None:
3030
super().__init__(methodName)
@@ -188,12 +188,12 @@ class AppendDeleteTestParams:
188188
def test_append_and_delete(
189189
self, _test_name: str, test_params: AppendDeleteTestParams
190190
) -> None:
191-
delta_store = DeltaStore()
191+
delta_store = DeltaStoreTrec()
192192
for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items():
193193
for lookup in lookup_list:
194194
delta_store.append(
195195
batch_idx=lookup.batch_idx,
196-
table_fqn=table_fqn,
196+
fqn=table_fqn,
197197
ids=lookup.ids,
198198
states=lookup.states,
199199
)
@@ -783,15 +783,15 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None:
783783
"""
784784
Test the compact method of DeltaStore.
785785
"""
786-
# Create a DeltaStore with the specified embdUpdateMode
787-
delta_store = DeltaStore(embdUpdateMode=test_params.embdUpdateMode)
786+
# Create a DeltaStoreTrec with the specified embdUpdateMode
787+
delta_store = DeltaStoreTrec(embdUpdateMode=test_params.embdUpdateMode)
788788

789789
# Populate the DeltaStore with the test lookups
790790
for table_fqn, lookup_list in test_params.table_fqn_to_lookups.items():
791791
for lookup in lookup_list:
792792
delta_store.append(
793793
batch_idx=lookup.batch_idx,
794-
table_fqn=table_fqn,
794+
fqn=table_fqn,
795795
ids=lookup.ids,
796796
states=lookup.states,
797797
)

0 commit comments

Comments
 (0)