# pyre-strict
import logging as logger
from collections import Counter, OrderedDict
- from typing import Dict, Iterable, List, Optional
+ from typing import Dict, Iterable, List, Optional, Tuple

import torch
+ from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
+ from fbgemm_gpu.split_table_batched_embeddings_ops import (
+     SplitTableBatchedEmbeddingBagsCodegen,
+ )

from torch import nn
+ from torchrec.distributed.batched_embedding_kernel import BatchedFusedEmbedding

from torchrec.distributed.embedding import ShardedEmbeddingCollection
from torchrec.distributed.embedding_lookup import (
+     BatchedFusedEmbeddingBag,
    GroupedEmbeddingsLookup,
    GroupedPooledEmbeddingsLookup,
)
    EmbdUpdateMode,
    TrackingMode,
)
+ from torchrec.distributed.utils import none_throws
+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
    # This mode supports approximate top-k delta-row selection, which can be
    # obtained by running momentum.norm().topk().
    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
+     # MOMENTUM_DIFF keeps a running sum of the squared gradients per row.
+     # Within each publishing interval, we track the starting value of this running
+     # sum for all used rows and then do a lookup when ``get_delta`` is called to query
+     # the latest sum. We then compute the delta between the two values and return it
+     # together with the row ids.
+     TrackingMode.MOMENTUM_DIFF: EmbdUpdateMode.FIRST,
+     # The same as MOMENTUM_DIFF; added for backward compatibility.
+     TrackingMode.ROWWISE_ADAGRAD: EmbdUpdateMode.FIRST,
}

# Tracking is currently only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection.
@@ -99,6 +115,7 @@ def __init__(

        # from module FQN to ShardedEmbeddingCollection/ShardedEmbeddingBagCollection
        self.tracked_modules: Dict[str, nn.Module] = {}
+         self.table_to_fqn: Dict[str, str] = {}
        self.feature_to_fqn: Dict[str, str] = {}
        # Generate the mapping from FQN to feature names.
        self.fqn_to_feature_names()
@@ -180,6 +197,11 @@ def record_lookup(
        # In MOMENTUM_LAST mode, we track per feature IDs and corresponding momentum values received in the current batch.
        elif self._mode == TrackingMode.MOMENTUM_LAST:
            self.record_momentum(emb_module, kjt)
+         elif (
+             self._mode == TrackingMode.MOMENTUM_DIFF
+             or self._mode == TrackingMode.ROWWISE_ADAGRAD
+         ):
+             self.record_rowwise_optim_state(emb_module, kjt)
        else:
            raise NotImplementedError(f"Tracking mode {self._mode} is not supported")

@@ -278,6 +300,60 @@ def record_momentum(
                states=per_key_states,
            )

+     def record_rowwise_optim_state(
+         self,
+         emb_module: nn.Module,
+         kjt: KeyedJaggedTensor,
+     ) -> None:
+         opt_states: List[List[torch.Tensor]] = (
+             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+             #  `split_optimizer_states`.
+             emb_module._emb_module.split_optimizer_states()
+         )
+         proxy: torch.Tensor = torch.cat([state[0] for state in opt_states])
+         states = proxy[kjt.values()]
+         assert (
+             kjt.values().numel() == states.numel()
+         ), f"number of ids and states mismatch, expect {kjt.values()=}, {kjt.values().numel()}, but got {states.numel()}"
+         offsets: torch.Tensor = torch.ops.fbgemm.asynchronous_complete_cumsum(
+             torch.tensor(kjt.length_per_key(), dtype=torch.int64)
+         )
+         for i, key in enumerate(kjt.keys()):
+             fqn = self.feature_to_fqn[key]
+             per_key_states = states[offsets[i] : offsets[i + 1]]
+             self.store.append(
+                 batch_idx=self.curr_batch_idx,
+                 table_fqn=fqn,
+                 ids=kjt[key].values(),
+                 states=per_key_states,
+             )
+
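
For context on record_rowwise_optim_state above: the length_per_key offsets are what slice the flat, id-aligned state tensor back into per-key segments. A toy sketch of that slicing follows, with assumed values and a plain torch.cumsum standing in for torch.ops.fbgemm.asynchronous_complete_cumsum:

import torch

length_per_key = [2, 3]                                  # e.g. feature "f1" has 2 ids, "f2" has 3
states = torch.tensor([10.0, 11.0, 20.0, 21.0, 22.0])    # one optimizer-state value per id, concatenated

# Complete cumsum = exclusive prefix sum with a trailing total: [0, 2, 5]
offsets = torch.cat([
    torch.zeros(1, dtype=torch.int64),
    torch.cumsum(torch.tensor(length_per_key, dtype=torch.int64), dim=0),
])

per_key = {key: states[offsets[i] : offsets[i + 1]] for i, key in enumerate(["f1", "f2"])}
print(per_key["f1"])  # tensor([10., 11.])
print(per_key["f2"])  # tensor([20., 21., 22.])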
+     def get_latest(self) -> Dict[str, torch.Tensor]:
+         ret: Dict[str, torch.Tensor] = {}
+         for module in self.tracked_modules.values():
+             # pyre-fixme[29]:
+             for lookup in module._lookups:
+                 for embs_module in lookup._emb_modules:
+                     assert isinstance(
+                         embs_module, (BatchedFusedEmbeddingBag, BatchedFusedEmbedding)
+                     ), f"expect BatchedFusedEmbeddingBag or BatchedFusedEmbedding, but {type(embs_module)} found"
+                     tbe = embs_module._emb_module
+
+                     assert isinstance(tbe, SplitTableBatchedEmbeddingBagsCodegen)
+                     table_names = [t.name for t in embs_module._config.embedding_tables]
+                     opt_states = tbe.split_optimizer_states()
+                     assert len(table_names) == len(opt_states)
+
+                     for i, table_name in enumerate(table_names):
+                         emb_fqn = self.table_to_fqn[table_name]
+                         table_state = opt_states[i][0]
+                         assert (
+                             emb_fqn not in ret
+                         ), f"a table with {emb_fqn} already exists"
+                         ret[emb_fqn] = table_state
+
+         return ret
+
    def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """
        Return a dictionary of hit local IDs for each sparse feature. Ids are
@@ -289,7 +365,13 @@ def get_delta_ids(self, consumer: Optional[str] = None) -> Dict[str, torch.Tenso
        per_table_delta_rows = self.get_delta(consumer)
        return {fqn: delta_rows.ids for fqn, delta_rows in per_table_delta_rows.items()}

-     def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
+     def get_delta(
+         self,
+         consumer: Optional[str] = None,
+         top_percentage: Optional[float] = 1.0,
+         per_table_percentage: Optional[Dict[str, Tuple[float, str]]] = None,
+         sorted_by_indices: Optional[bool] = True,
+     ) -> Dict[str, DeltaRows]:
        """
        Return a dictionary of hit local IDs and parameter states / embeddings for each sparse feature. The values are first keyed by submodule FQN.

@@ -314,6 +396,17 @@ def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
        self.per_consumer_batch_idx[consumer] = index_end
        if self._delete_on_read:
            self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
+
+         if self._mode in (TrackingMode.MOMENTUM_DIFF, TrackingMode.ROWWISE_ADAGRAD):
+             square_sum_map = self.get_latest()
+             for fqn, rows in tracker_rows.items():
+                 assert (
+                     fqn in square_sum_map
+                 ), f"{fqn} not found in {square_sum_map.keys()}"
+                 # pyre-fixme[58]: `-` is not supported for operand types `Tensor`
+                 #  and `Optional[Tensor]`.
+                 rows.states = square_sum_map[fqn][rows.ids] - rows.states
+
        return tracker_rows

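The branch added to get_delta above rewrites each tracked FIRST snapshot into a difference against the latest optimizer state returned by get_latest (a mapping from table FQN to the full rowwise state tensor). A worked toy example of that subtraction, with assumed numbers rather than values from this PR:

import torch

latest = torch.tensor([5.0, 8.0, 2.0, 9.0])  # square_sum_map[fqn]: current running sums for every row
ids = torch.tensor([1, 3])                   # rows.ids: rows touched during the interval
start = torch.tensor([6.5, 4.0])             # rows.states: running sums when each row was first seen

delta = latest[ids] - start                  # what rows.states holds after get_delta returns
print(delta)                                 # tensor([1.5000, 5.0000])
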
    def get_tracked_modules(self) -> Dict[str, nn.Module]:
@@ -330,7 +423,6 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
            return self._fqn_to_feature_map

        table_to_feature_names: Dict[str, List[str]] = OrderedDict()
-         table_to_fqn: Dict[str, str] = OrderedDict()
        for fqn, named_module in self._model.named_modules():
            split_fqn = fqn.split(".")
            # Skipping partial FQNs present in fqns_to_skip
@@ -356,13 +448,13 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:
                # will incorrectly match fqn with all the table names that have the same prefix
                if table_name in split_fqn:
                    embedding_fqn = self._clean_fqn_fn(fqn)
-                     if table_name in table_to_fqn:
+                     if table_name in self.table_to_fqn:
                        # Sanity check for validating that we don't have more than one table mapping to the same fqn.
                        logger.warning(
-                             f"Override {table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
+                             f"Override {self.table_to_fqn[table_name]} with {embedding_fqn} for entry {table_name}"
                        )
-                     table_to_fqn[table_name] = embedding_fqn
-         logger.info(f"Table to fqn: {table_to_fqn}")
+                     self.table_to_fqn[table_name] = embedding_fqn
+         logger.info(f"Table to fqn: {self.table_to_fqn}")
        flatten_names = [
            name for names in table_to_feature_names.values() for name in names
        ]
@@ -375,15 +467,15 @@ def fqn_to_feature_names(self) -> Dict[str, List[str]]:

        fqn_to_feature_names: Dict[str, List[str]] = OrderedDict()
        for table_name in table_to_feature_names:
-             if table_name not in table_to_fqn:
+             if table_name not in self.table_to_fqn:
                # This is likely unexpected, where we can't locate the FQN associated with this table.
                logger.warning(
-                     f"Table {table_name} not found in {table_to_fqn}, skipping"
+                     f"Table {table_name} not found in {self.table_to_fqn}, skipping"
                )
                continue
-             fqn_to_feature_names[table_to_fqn[table_name]] = table_to_feature_names[
-                 table_name
-             ]
+             fqn_to_feature_names[self.table_to_fqn[table_name]] = (
+                 table_to_feature_names[table_name]
+             )
        self._fqn_to_feature_map = fqn_to_feature_names
        return fqn_to_feature_names

@@ -451,6 +543,24 @@ def _validate_and_init_tracker_fns(self) -> None:
                        (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
                    )
                    lookup.register_optim_state_tracker_fn(self.record_lookup)
+             elif (
+                 self._mode == TrackingMode.ROWWISE_ADAGRAD
+                 or self._mode == TrackingMode.MOMENTUM_DIFF
+             ):
+                 # pyre-ignore[29]:
+                 for lookup in module._lookups:
+                     assert isinstance(
+                         lookup,
+                         (GroupedEmbeddingsLookup, GroupedPooledEmbeddingsLookup),
+                     ) and all(
+                         # TorchRec maps ROWWISE_ADAGRAD to EXACT_ROWWISE_ADAGRAD
+                         # pyre-ignore[16]:
+                         emb._emb_module.optimizer == OptimType.EXACT_ROWWISE_ADAGRAD
+                         # pyre-ignore[16]:
+                         or emb._emb_module.optimizer == OptimType.PARTIAL_ROWWISE_ADAM
+                         for emb in lookup._emb_modules
+                     )
+                     lookup.register_optim_state_tracker_fn(self.record_lookup)
            else:
                raise NotImplementedError(
                    f"Tracking mode {self._mode} is not supported"
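
The assertions above expect the fused TBE optimizer to carry rowwise state. One way such a setup typically arises in TorchRec (an assumed configuration, not shown in this diff) is by passing fused_params to the sharder:

from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# EXACT_ROWWISE_ADAGRAD keeps a per-row sum of squared gradients, which the
# tracker reads back through split_optimizer_states().
sharder = EmbeddingBagCollectionSharder(
    fused_params={
        "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
        "learning_rate": 0.1,
    }
)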