Skip to content

Commit 8eae1e9

Browse files
aliafzalmeta-codesync[bot]
authored andcommitted
Update DeltaStore API (#3469)
Summary: Pull Request resolved: #3469 internal General Context: We are in the process of transition to a unified DeltaTracker and this is 2/n diffs representing changes towards the transition. Specific Context: Update DeltaStore APIs to match Memstore APIs for backward compatibility. Differential Revision: D80614586 fbshipit-source-id: c929fb03abf6fe21260a84df6a1b85e846da4b92
1 parent 4984377 commit 8eae1e9

File tree

5 files changed

+15
-15
lines changed

5 files changed

+15
-15
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,14 +466,14 @@ def get_model_tracker(self) -> ModelDeltaTracker:
466466
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
467467
return self.model_delta_tracker
468468

469-
def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
469+
def get_unique(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
470470
"""
471471
Returns the delta rows for the given consumer.
472472
"""
473473
assert (
474474
self.model_delta_tracker is not None
475475
), "Model tracker is not initialized. Add ModelTrackerConfig at DistributedModelParallel init."
476-
return self.model_delta_tracker.get_delta(consumer)
476+
return self.model_delta_tracker.get_unique(consumer)
477477

478478
def sparse_grad_parameter_names(
479479
self, destination: Optional[List[str]] = None, prefix: str = ""

torchrec/distributed/model_tracker/delta_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def compact(self, start_idx: int, end_idx: int) -> None:
125125
pass
126126

127127
@abstractmethod
128-
def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
128+
def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
129129
"""
130130
Return all unique/delta ids per table from the Delta Store.
131131
@@ -224,7 +224,7 @@ def compact(self, start_idx: int, end_idx: int) -> None:
224224
)
225225
self.per_fqn_lookups = new_per_fqn_lookups
226226

227-
def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
227+
def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
228228
r"""
229229
Return all unique/delta ids per table from the Delta Store.
230230
"""

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,10 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso
362362
Args:
363363
consumer (str, optional): The consumer to retrieve unique IDs for. If not specified, "default" is used as the default consumer.
364364
"""
365-
per_table_delta_rows = self.get_delta(consumer)
365+
per_table_delta_rows = self.get_unique(consumer)
366366
return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}
367367

368-
def get_delta(
368+
def get_unique(
369369
self,
370370
consumer: Optional[str] = None,
371371
top_percentage: Optional[float] = 1.0,
@@ -390,7 +390,7 @@ def get_delta(
390390
# and index_start could be equal to index_end, in which case we should not compact again.
391391
if index_start < index_end:
392392
self.compact(index_start, index_end)
393-
tracker_rows = self.store.get_delta(
393+
tracker_rows = self.store.get_unique(
394394
from_idx=self.per_consumer_batch_idx[consumer]
395395
)
396396
self.per_consumer_batch_idx[consumer] = index_end

torchrec/distributed/model_tracker/tests/test_delta_store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -806,8 +806,8 @@ def test_compact(self, _test_name: str, test_params: CompactTestParams) -> None:
806806
delta_store.compact(
807807
start_idx=test_params.start_idx, end_idx=test_params.end_idx
808808
)
809-
# Verify the result using get_delta method
810-
delta_result = delta_store.get_delta()
809+
# Verify the result using get_unique method
810+
delta_result = delta_store.get_unique()
811811

812812
# compare all fqns in the result
813813
for table_fqn, delta_rows in test_params.expected_delta.items():

torchrec/distributed/model_tracker/tests/test_model_delta_tracker.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def test_fqn_to_feature_names(
447447
),
448448
),
449449
(
450-
"get_delta",
450+
"get_unique",
451451
ModelDeltaTrackerInputTestParams(
452452
embedding_config_type=EmbeddingConfig,
453453
embedding_tables=[
@@ -464,7 +464,7 @@ def test_fqn_to_feature_names(
464464
model_tracker_config=ModelTrackerConfig(),
465465
),
466466
TrackerNotInitOutputTestParams(
467-
dmp_tracker_atter="get_delta",
467+
dmp_tracker_atter="get_unique",
468468
),
469469
),
470470
]
@@ -1843,7 +1843,7 @@ def _test_embedding_mode(
18431843
tracked_out.sum().backward()
18441844
baseline_out.sum().backward()
18451845

1846-
delta_rows = dt.get_delta()
1846+
delta_rows = dt.get_unique()
18471847

18481848
table_fqns = dt.fqn_to_feature_names().keys()
18491849
table_fqns_list = list(table_fqns)
@@ -1964,7 +1964,7 @@ def _test_multiple_get(
19641964
unittest.TestCase().assertTrue(tracked_out.allclose(baseline_out))
19651965
tracked_out.sum().backward()
19661966
baseline_out.sum().backward()
1967-
delta_rows = dt.get_delta()
1967+
delta_rows = dt.get_unique()
19681968

19691969
# Verify that the current batch index is correct
19701970
unittest.TestCase().assertTrue(dt.curr_batch_idx, i + 1)
@@ -2093,7 +2093,7 @@ def _test_duplication_with_momentum(
20932093
dt_model_opt.step()
20942094
baseline_opt.step()
20952095

2096-
delta_rows = dt.get_delta()
2096+
delta_rows = dt.get_unique()
20972097
for table_fqn in table_fqns_list:
20982098
ids = delta_rows[table_fqn].ids
20992099
states = none_throws(delta_rows[table_fqn].states)
@@ -2162,7 +2162,7 @@ def _test_duplication_with_rowwise_adagrad(
21622162

21632163
end_momentums = tbe.split_optimizer_states()[0][0].detach().clone()
21642164

2165-
delta_rows = dt.get_delta()
2165+
delta_rows = dt.get_unique()
21662166
table_fqn = table_fqns_list[0]
21672167

21682168
ids = delta_rows[table_fqn].ids

0 commit comments

Comments
 (0)