
Commit 397a6f6

aliafzal authored and meta-codesync[bot] committed

Adding support for tracking optimizer states in Model Delta Tracker. (#3143)
Summary:
Pull Request resolved: #3143
X-link: #3143

### Overview

This diff adds support for tracking optimizer states in the Model Delta Tracker system. It introduces a new tracking mode, `MOMENTUM_LAST`, which tracks momentum values from optimizers to support approximate top-k delta-row selection.

### Key Changes

#### 1. Optimizer State Tracking Support

* Added an `optim_state_tracker_fn` attribute to the `GroupedEmbeddingsLookup` and `GroupedPooledEmbeddingsLookup` classes, which are responsible for traversing the BatchedFused modules.
* Implemented a `register_optim_state_tracker_fn()` method in both classes to register the tracking callable.
* Tracking calls are invoked after each lookup operation.

#### 2. Model Delta Tracker Changes

* Added a `record_momentum()` method to track momentum values from optimizer states, and wired it into the `record_lookup()` function.
* Added validation and optimizer-tracker-function logic to support the new `MOMENTUM_LAST` mode.

#### 3. New Tracking Mode

* Added `TrackingMode.MOMENTUM_LAST` to `types.py` (torchrec/distributed/model_tracker/types.py).
* Maps to `EmbdUpdateMode.LAST` to capture the most recent momentum values.

Differential Revision: D76868111

fbshipit-source-id: bde3d4be8d3df7fe5b2f284a262c50a5313c1dc0
1 parent 80dbb88 commit 397a6f6

File tree: 7 files changed, +279 −18 lines changed

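Before the per-file diffs, here is a minimal, self-contained sketch of the approximate top-k delta-row selection described in the summary: rank the rows of one embedding table by the L2 norm of their optimizer momentum and keep the most-updated rows. This is plain PyTorch with hypothetical shapes and variable names, not TorchRec code.

```python
import torch

# Hypothetical standalone illustration of momentum.norm().topk() selection.
# num_rows, emb_dim, k and the random momentum buffer are made up for the example.
num_rows, emb_dim, k = 1000, 16, 8

# One momentum entry per embedding row (e.g. a row-wise optimizer state buffer).
momentum = torch.randn(num_rows, emb_dim)

# Per-row momentum magnitude, then the indices of the k most-updated rows.
row_norms = momentum.norm(dim=1)              # shape: (num_rows,)
topk_norms, topk_row_ids = row_norms.topk(k)

print(topk_row_ids)  # candidate "delta rows" to publish since the last snapshot
```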

torchrec/distributed/embedding.py

Lines changed: 1 addition & 1 deletion
@@ -1587,7 +1587,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)

             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST, self._module_fqn, sharding_type

torchrec/distributed/embedding_lookup.py

Lines changed: 48 additions & 3 deletions
@@ -10,7 +10,7 @@
 import logging
 from abc import ABC
 from collections import OrderedDict
-from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -208,6 +208,10 @@ def __init__(
         )

         self.grouped_configs = grouped_configs
+        # Model tracker function to tracker optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None

     def _create_embedding_kernel(
         self,
@@ -315,7 +319,13 @@ def forward(
             self._feature_splits,
         )
         for emb_op, features in zip(self._emb_modules, features_by_group):
-            embeddings.append(emb_op(features).view(-1))
+            lookup = emb_op(features).view(-1)
+            embeddings.append(lookup)
+
+            # Model tracker optimizer state function, will only be set called
+            # when model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)

         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)

@@ -420,6 +430,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()

+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Model tracker function to tracker optimizer state
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+

 class GroupedEmbeddingsUpdate(BaseEmbeddingUpdate[KeyedJaggedTensor]):
     """
@@ -519,6 +542,10 @@ def __init__(
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
+        # Model tracker function to tracker optimizer state
+        self.optim_state_tracker_fn: Optional[
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
+        ] = None

     def _create_embedding_kernel(
         self,
@@ -678,7 +705,12 @@ def forward(
                     features._weights, self._scale_gradient_factor
                 )

-            embeddings.append(emb_op(features))
+            lookup = emb_op(features)
+            embeddings.append(lookup)
+            # Model tracker optimizer state function, will only be set called
+            # when model tracker is configured to track optimizer state
+            if self.optim_state_tracker_fn is not None:
+                self.optim_state_tracker_fn(emb_op, features, lookup)

         if features.variable_stride_per_key() and len(self._emb_modules) > 1:
             stride_per_rank_per_key = list(
@@ -811,6 +843,19 @@ def purge(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.purge()

+    def register_optim_state_tracker_fn(
+        self,
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
+    ) -> None:
+        """
+        Model tracker function to tracker optimizer state
+
+        Args:
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+
+        """
+        self.optim_state_tracker_fn = record_fn
+

 class MetaInferGroupedEmbeddingsLookup(
     BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
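The diff above adds `register_optim_state_tracker_fn()` to both lookup classes and invokes the registered callable after each per-group lookup. As a rough sketch of what a caller could register, with the `log_optim_state` callback and the `lookup_module` variable being hypothetical and only the callable signature taken from the diff:

```python
import torch
from torch import nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


# Hypothetical callback matching the signature the diff expects:
# Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None].
def log_optim_state(
    emb_op: nn.Module, features: KeyedJaggedTensor, lookup: torch.Tensor
) -> None:
    # emb_op is the per-group embedding kernel that just ran, features holds
    # the ids it looked up, and lookup is the embedding output for those ids.
    print(f"{type(emb_op).__name__}: {features.values().numel()} ids, out {tuple(lookup.shape)}")


# Assuming `lookup_module` is a GroupedEmbeddingsLookup or
# GroupedPooledEmbeddingsLookup instance (e.g. taken from a sharded module's
# `_lookups`), registration would look like:
# lookup_module.register_optim_state_tracker_fn(log_optim_state)
```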

torchrec/distributed/embedding_types.py

Lines changed: 3 additions & 3 deletions
@@ -391,7 +391,7 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
         self.post_lookup_tracker_fn: Optional[
-            Callable[[KeyedJaggedTensor, torch.Tensor], None]
+            Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]
         ] = None
         self.post_odist_tracker_fn: Optional[Callable[..., None]] = None

@@ -444,14 +444,14 @@ def train(self, mode: bool = True): # pyre-ignore[3]

     def register_post_lookup_tracker_fn(
         self,
-        record_fn: Callable[[KeyedJaggedTensor, torch.Tensor], None],
+        record_fn: Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None],
     ) -> None:
         """
         Register a function to be called after lookup is done. This is used for
         tracking the lookup results and optimizer states.

         Args:
-            record_fn (Callable[[KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.
+            record_fn (Callable[[nn.Module, KeyedJaggedTensor, torch.Tensor], None]): A custom record function to be called after lookup is done.

         """
         if self.post_lookup_tracker_fn is not None:
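Because `register_post_lookup_tracker_fn()` now takes a three-argument callable, an existing post-lookup callback must accept the module as its first parameter. A small hedged illustration follows; the `my_tracker` callback is hypothetical and not part of TorchRec:

```python
import torch
from torch import nn
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Before this commit a post-lookup tracker took (kjt, states):
# def my_tracker(kjt: KeyedJaggedTensor, states: torch.Tensor) -> None: ...


# After this commit it also receives the module that produced the lookup,
# which is what lets a tracker reach optimizer state on that module.
def my_tracker(
    emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
) -> None:
    pass
```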

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 1 deletion
@@ -1671,7 +1671,7 @@ def compute_and_output_dist(
         ):
             embs = lookup(features)
             if self.post_lookup_tracker_fn is not None:
-                self.post_lookup_tracker_fn(features, embs)
+                self.post_lookup_tracker_fn(self, features, embs)

             with maybe_annotate_embedding_event(
                 EmbeddingEvent.OUTPUT_DIST,

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 75 additions & 7 deletions
@@ -13,7 +13,12 @@
 import torch

 from torch import nn
+
 from torchrec.distributed.embedding import ShardedEmbeddingCollection
+from torchrec.distributed.embedding_lookup import (
+    GroupedEmbeddingsLookup,
+    GroupedPooledEmbeddingsLookup,
+)
 from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
 from torchrec.distributed.model_tracker.delta_store import DeltaStore
 from torchrec.distributed.model_tracker.types import (
@@ -27,9 +32,16 @@
     # Only IDs are tracked, no additional state is stored.
     TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
     # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that
-    # the earliest embedding values are stored since the last checkpoint or snapshot.
-    # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk().
+    # the earliest embedding values are stored since the last checkpoint
+    # or snapshot. This mode is used for computing topk delta rows, which
+    # is currently achieved by running (new_emb - old_emb).norm().topk().
     TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
+    # TrackingMode.MOMENTUM utilizes EmbdUpdateMode.LAST to ensure that
+    # the most recent momentum values—capturing the accumulated gradient
+    # direction and magnitude—are stored since the last batch.
+    # This mode supports approximate top-k delta-row selection, can be
+    # obtained by running momentum.norm().topk().
+    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
 }

 # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -141,7 +153,9 @@ def trigger_compaction(self) -> None:
         # Update the current compact index to the end index to avoid duplicate compaction.
         self.curr_compact_index = end_idx

-    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
+    def record_lookup(
+        self, emb_module: nn.Module, kjt: KeyedJaggedTensor, states: torch.Tensor
+    ) -> None:
         """
         Records the IDs from a given KeyedJaggedTensor and their corresponding embeddings/parameter states.

@@ -152,6 +166,7 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         (in ID_ONLY mode) or both IDs and their corresponding embeddings (in EMBEDDING mode).

         Args:
+            emb_module (nn.Module): The embedding module in which the lookup was performed.
             kjt (KeyedJaggedTensor): The KeyedJaggedTensor containing IDs to record.
             states (torch.Tensor): The embeddings or states corresponding to the IDs in the kjt.
         """
@@ -162,7 +177,9 @@ def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
         # In EMBEDDING mode, we track per feature IDs and corresponding embeddings received in the current batch.
         elif self._mode == TrackingMode.EMBEDDING:
             self.record_embeddings(kjt, states)
-
+        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
+        elif self._mode == TrackingMode.MOMENTUM_LAST:
+            self.record_momentum(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

@@ -228,6 +245,39 @@ def record_embeddings(
                 states=torch.cat(per_table_emb[table_fqn]),
             )

+    def record_momentum(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        # FIXME: this is the momentum from last iteration, use momentum from current iter
+        # for correctness.
+        # pyre-ignore Undefined attribute [16]:
+        momentum = emb_module._emb_module.momentum1_dev
+        # FIXME: support multiple tables per group, information can be extracted from
+        # module._config (i.e., GroupedEmbeddingConfig)
+        # pyre-ignore Undefined attribute [16]:
+        states = momentum.view(-1, emb_module._config.embedding_dims()[0])[
+            kjt.values()
+        ].norm(dim=1)
+
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()} "
+
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -380,13 +430,31 @@ def _clean_fqn_fn(self, fqn: str) -> str:
     def _validate_and_init_tracker_fns(self) -> None:
         "To validate the mode is supported for the given module"
         for module in self.tracked_modules.values():
+            # EMBEDDING mode is only supported for ShardedEmbeddingCollection
             assert not (
                 isinstance(module, ShardedEmbeddingBagCollection)
                 and self._mode == TrackingMode.EMBEDDING
             ), "EBC's lookup returns pooled embeddings and currently, we do not support tracking raw embeddings."
-            # register post lookup function
-            # pyre-ignore[29]
-            module.register_post_lookup_tracker_fn(self.record_lookup)
+
+            if (
+                self._mode == TrackingMode.ID_ONLY
+                or self._mode == TrackingMode.EMBEDDING
+            ):
+                # register post lookup function
+                # pyre-ignore[29]
+                module.register_post_lookup_tracker_fn(self.record_lookup)
+            elif self._mode == TrackingMode.MOMENTUM_LAST:
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+            else:
+                raise NotImplementedError(
+                    f"Tracking mode {self._mode} is not supported"
+                )
         # register auto compaction function at odist
         if self._auto_compact:
             # pyre-ignore[29]