
Commit 5e9763d

aliafzal authored and meta-codesync[bot] committed
Adding support for MOMENTUM_DIFF and ROWWISE_ADAGRAD optimizer states (#3144)
Summary:
Pull Request resolved: #3144
X-link: #3144

This diff extends the Model Delta Tracker to support two new tracking modes: `MOMENTUM_DIFF` and `ROWWISE_ADAGRAD`, which enable tracking of rowwise optimizer states for more sophisticated gradient analysis.

Differential Revision: D76918891

fbshipit-source-id: 05b63979a05e3d896c3c61c9fc8a56d2558220f3
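As a quick illustration of what the two new modes report (a minimal sketch with made-up values; none of these names are part of the TorchRec API): rowwise Adagrad keeps one running sum of squared gradients per embedding row, the tracker records that sum the first time a row is touched in a publishing interval, and get_delta later subtracts it from the latest sum.

# Illustrative sketch only: local tensors stand in for the tracked optimizer state.
import torch

row_ids      = torch.tensor([10, 42, 77])            # rows touched during the interval
first_state  = torch.tensor([1.0, 4.0, 9.0])         # rowwise sum of squared grads when each row was first seen
latest_state = torch.tensor([2.5, 4.0, 12.0])        # the same rows at get_delta() time

delta = latest_state - first_state                   # what MOMENTUM_DIFF / ROWWISE_ADAGRAD report per row
print(dict(zip(row_ids.tolist(), delta.tolist())))   # {10: 1.5, 42: 0.0, 77: 3.0}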
1 parent 397a6f6 commit 5e9763d

File tree: 3 files changed, +296 / -12 lines

torchrec/distributed/model_tracker/model_delta_tracker.py

Lines changed: 122 additions & 12 deletions
@@ -8,14 +8,20 @@
 # pyre-strict
 import logging as logger
 from collections import Counter, OrderedDict
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Tuple

 import torch
+from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+from fbgemm_gpu.split_table_batched_embeddings_ops import (
+    SplitTableBatchedEmbeddingBagsCodegen,
+)

 from torch import nn
+from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding

 from torchrec.distributed.embedding import ShardedEmbeddingCollection
 from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
     GroupedEmbeddingsLookup,
     GroupedPooledEmbeddingsLookup,
 )
@@ -26,6 +32,8 @@
     EmbdUpdateMode,
     TrackingMode,
 )
+from torchrec.distributed.utils import none_throws
+
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

 UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
@@ -42,6 +50,14 @@
     # This mode supports approximate top-k delta-row selection, can be
     # obtained by running momentum.norm().topk().
     TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+    # MOMENTUM_DIFF keeps a running sum of the square of the gradients per row.
+    # Within each publishing interval, we track the starting value of this running
+    # sum on all used rows and then do a lookup when ``get_delta`` is called to query
+    # the latest sum. Then we can compute the delta of the two values and return them
+    # together with the row ids.
+    TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+    # The same as MOMENTUM_DIFF. Adding for backward compatibility.
+    TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
 }

 # Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
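A note on why both new modes map to EmbdUpdateMode.FIRST: if a row is hit several times within a publishing interval, only the earliest recorded state should survive so that the later latest-minus-first subtraction covers the whole interval. A hypothetical, minimal version of that dedup rule (not the tracker's actual storage code):

import torch

ids    = torch.tensor([7, 3, 7, 9])           # row 7 is looked up twice in the interval
states = torch.tensor([1.0, 2.0, 5.0, 3.0])   # optimizer state observed at each lookup

first_seen = {}
for row, state in zip(ids.tolist(), states.tolist()):
    first_seen.setdefault(row, state)          # FIRST semantics: keep the earliest value, ignore later hits
print(first_seen)                              # {7: 1.0, 3: 2.0, 9: 3.0}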
@@ -99,6 +115,7 @@ def __init__(

         # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
         self.tracked_modules: Dict[str, nn.Module] = {}
+        self.table_to_fqn: Dict[str, str] = {}
         self.feature_to_fqn: Dict[str, str] = {}
         # Generate the mapping from FQN to feature names.
         self.fqn_to_feature_names()
@@ -180,6 +197,11 @@ def record_lookup(
         # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
         elif self._mode == TrackingMode.MOMENTUM_LAST:
             self.record_momentum(emb_module, kjt)
+        elif (
+            self._mode == TrackingMode.MOMENTUM_DIFF
+            or self._mode == TrackingMode.ROWWISE_ADAGRAD
+        ):
+            self.record_rowwise_optim_state(emb_module, kjt)
         else:
             raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

@@ -278,6 +300,60 @@ def record_momentum(
                 states=per_key_states,
             )

+    def record_rowwise_optim_state(
+        self,
+        emb_module: nn.Module,
+        kjt: KeyedJaggedTensor,
+    ) -> None:
+        opt_states: List[List[torch.Tensor]] = (
+            # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+            # `split_optimizer_states`.
+            emb_module._emb_module.split_optimizer_states()
+        )
+        proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+        states = proxy[kjt.values()]
+        assert (
+            kjt.values().numel() == states.numel()
+        ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()} "
+        offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+        )
+        for i, key in enumerate(kjt.keys()):
+            fqn = self.feature_to_fqn[key]
+            per_key_states = states[offsets[i] : offsets[i + 1]]
+            self.store.append(
+                batch_idx=self.curr_batch_idx,
+                table_fqn=fqn,
+                ids=kjt[key].values(),
+                states=per_key_states,
+            )
+
+    def get_latest(self) -> Dict[str, torch.Tensor]:
+        ret: Dict[str, torch.Tensor] = {}
+        for module in self.tracked_modules.values():
+            # pyre-fixme[29]:
+            for lookup in module._lookups:
+                for embs_module in lookup._emb_modules:
+                    assert isinstance(
+                        embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                    ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found"
+                    tbe = embs_module._emb_module

+                    assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                    table_names = [t.name for t in embs_module._config.embedding_tables]
+                    opt_states = tbe.split_optimizer_states()
+                    assert len(table_names) == len(opt_states)
+
+                    for i, table_name in enumerate(table_names):
+                        emb_fqn = self.table_to_fqn[table_name]
+                        table_state = opt_states[i][0]
+                        assert (
+                            emb_fqn not in ret
+                        ), f"a table with {emb_fqn} already exists"
+                        ret[emb_fqn] = table_state
+
+        return ret
+
     def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         """
         Return a dictionary of hit local IDs for each sparse feature. Ids are
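For readers unfamiliar with the offsets trick in record_rowwise_optim_state above: the complete cumulative sum of the per-key lengths yields slice boundaries into the flat per-id state tensor, one slice per feature key. A plain-torch sketch of that computation (illustrative values, no fbgemm dependency; the real code calls torch.ops.fbgemm.asynchronous_complete_cumsum):

import torch

length_per_key = torch.tensor([2, 3, 1], dtype=torch.int64)   # number of ids per feature key in the KJT
offsets = torch.cat([torch.zeros(1, dtype=torch.int64), length_per_key.cumsum(0)])
# offsets == tensor([0, 2, 5, 6]); one boundary per key plus the total

states = torch.arange(6, dtype=torch.float32)                 # one looked-up state per id, flattened across keys
per_key_states = [states[offsets[i] : offsets[i + 1]] for i in range(len(length_per_key))]
# per_key_states == [tensor([0., 1.]), tensor([2., 3., 4.]), tensor([5.])]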
@@ -289,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
         per_table_delta_rows = self.get_delta(consumer)
         return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}

-    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+    def get_delta(
+        self,
+        consumer: Optional[str] = None,
+        top_percentage: Optional[float] = 1.0,
+        per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+        sorted_by_indices: Optional[bool] = True,
+    ) -> Dict[str, DeltaRows]:
         """
         Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The Values are first keyed by submodule FQN.

@@ -314,6 +396,17 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
             self.per_consumer_batch_idx[consumer] = index_end
         if self._delete_on_read:
             self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+        if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
+            square_sum_map = self.get_latest()
+            for fqn, rows in tracker_rows.items():
+                assert (
+                    fqn in square_sum_map
+                ), f"{fqn} not found in {square_sum_map.keys()}"
+                # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                # and `Optional[Tensor]`.
+                rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
         return tracker_rows

     def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -330,7 +423,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
             return self._fqn_to_feature_map

         table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-        table_to_fqn: Dict[str, str] = OrderedDict()
         for fqn, named_module in self._model.named_modules():
             split_fqn = fqn.split(".")
             # Skipping partial FQNs present in fqns_to_skip
@@ -356,13 +448,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                     # will incorrectly match fqn with all the table names that have the same prefix
                     if table_name in split_fqn:
                         embedding_fqn = self._clean_fqn_fn(fqn)
-                        if table_name in table_to_fqn:
+                        if table_name in self.table_to_fqn:
                             # Sanity check for validating that we don't have more then one table mapping to same fqn.
                             logger.warning(
-                                f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                                f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                             )
-                        table_to_fqn[table_name] = embedding_fqn
-        logger.info(f"Table to fqn: {table_to_fqn}")
+                        self.table_to_fqn[table_name] = embedding_fqn
+        logger.info(f"Table to fqn: {self.table_to_fqn}")
         flatten_names = [
             name for names in table_to_feature_names.values() for name in names
         ]
@@ -375,15 +467,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:

         fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
         for table_name in table_to_feature_names:
-            if table_name not in table_to_fqn:
+            if table_name not in self.table_to_fqn:
                 # This is likely unexpected, where we can't locate the FQN associated with this table.
                 logger.warning(
-                    f"Table {table_name} not found in {table_to_fqn}, skipping"
+                    f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                 )
                 continue
-            fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                table_name
-            ]
+            fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                table_to_feature_names[table_name]
+            )
         self._fqn_to_feature_map = fqn_to_feature_names
         return fqn_to_feature_names

@@ -451,6 +543,24 @@ def _validate_and_init_tracker_fns(self) -> None:
                         (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
                     )
                     lookup.register_optim_state_tracker_fn(self.record_lookup)
+            elif (
+                self._mode == TrackingMode.ROWWISE_ADAGRAD
+                or self._mode == TrackingMode.MOMENTUM_DIFF
+            ):
+                # pyre-ignore[29]:
+                for lookup in module._lookups:
+                    assert isinstance(
+                        lookup,
+                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                    ) and all(
+                        # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                        # pyre-ignore[16]:
+                        or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                        for emb in lookup._emb_modules
+                    )
+                    lookup.register_optim_state_tracker_fn(self.record_lookup)
             else:
                 raise NotImplementedError(
                     f"Tracking mode {self._mode} is not supported"
