1313
1414import torch
1515from torchrec .distributed .model_tracker .types import (
16- DeltaRows ,
17- EmbdUpdateMode ,
1816 IndexedLookup ,
17+ UniqueRows ,
18+ UpdateMode ,
1919)
2020from torchrec .distributed .utils import none_throws
2121
2222
2323def _compute_unique_rows (
2424 ids : List [torch .Tensor ],
2525 states : Optional [List [torch .Tensor ]],
26- mode : EmbdUpdateMode ,
27- ) -> DeltaRows :
26+ mode : UpdateMode ,
27+ ) -> UniqueRows :
2828 r"""
2929 To calculate unique ids and embeddings
3030 """
31- if mode == EmbdUpdateMode .NONE :
32- assert states is None , f"{ mode = } == EmbdUpdateMode .NONE but received embeddings"
31+ if mode == UpdateMode .NONE :
32+ assert states is None , f"{ mode = } == UpdateMode .NONE but received embeddings"
3333 unique_ids = torch .cat (ids ).unique (return_inverse = False )
34- return DeltaRows (ids = unique_ids , states = None )
34+ return UniqueRows (ids = unique_ids , states = None )
3535 else :
3636 assert (
3737 states is not None
38- ), f"{ mode = } != EmbdUpdateMode .NONE but received no embeddings"
38+ ), f"{ mode = } != UpdateMode .NONE but received no embeddings"
3939
4040 cat_ids = torch .cat (ids )
4141 cat_states = torch .cat (states )
4242
43- if mode == EmbdUpdateMode .LAST :
43+ if mode == UpdateMode .LAST :
4444 cat_ids = cat_ids .flip (dims = [0 ])
4545 cat_states = cat_states .flip (dims = [0 ])
4646
@@ -65,7 +65,7 @@ def _compute_unique_rows(
6565
6666 # Use first occurrence indices to select corresponding embedding row.
6767 unique_states = cat_states [first_occurrence ]
68- return DeltaRows (ids = unique_ids , states = unique_states )
68+ return UniqueRows (ids = unique_ids , states = unique_states )
6969
7070
7171class DeltaStore (ABC ):
@@ -81,7 +81,7 @@ class DeltaStore(ABC):
8181 """
8282
8383 @abstractmethod
84- def __init__ (self , embdUpdateMode : EmbdUpdateMode = EmbdUpdateMode .NONE ) -> None :
84+ def __init__ (self , updateMode : UpdateMode = UpdateMode .NONE ) -> None :
8585 pass
8686
8787 @abstractmethod
@@ -125,7 +125,7 @@ def compact(self, start_idx: int, end_idx: int) -> None:
125125 pass
126126
127127 @abstractmethod
128- def get_unique (self , from_idx : int = 0 ) -> Dict [str , DeltaRows ]:
128+ def get_unique (self , from_idx : int = 0 ) -> Dict [str , UniqueRows ]:
129129 """
130130 Return all unique/delta ids per table from the Delta Store.
131131
@@ -151,9 +151,9 @@ class DeltaStoreTrec(DeltaStore):
151151 how to handle duplicate ids when compacting or retrieving embeddings.
152152 """
153153
154- def __init__ (self , embdUpdateMode : EmbdUpdateMode = EmbdUpdateMode .NONE ) -> None :
155- super ().__init__ (embdUpdateMode )
156- self .embdUpdateMode = embdUpdateMode
154+ def __init__ (self , updateMode : UpdateMode = UpdateMode .NONE ) -> None :
155+ super ().__init__ (updateMode )
156+ self .updateMode = updateMode
157157 self .per_fqn_lookups : Dict [str , List [IndexedLookup ]] = {}
158158
159159 def append (
@@ -205,11 +205,11 @@ def compact(self, start_idx: int, end_idx: int) -> None:
205205 ids = [lookup .ids for lookup in lookups_to_compact ]
206206 states = (
207207 [none_throws (lookup .states ) for lookup in lookups_to_compact ]
208- if self .embdUpdateMode != EmbdUpdateMode .NONE
208+ if self .updateMode != UpdateMode .NONE
209209 else None
210210 )
211211 delta_rows = _compute_unique_rows (
212- ids = ids , states = states , mode = self .embdUpdateMode
212+ ids = ids , states = states , mode = self .updateMode
213213 )
214214 new_per_fqn_lookups [table_fqn ] = (
215215 lookups [:index_l ]
@@ -224,12 +224,12 @@ def compact(self, start_idx: int, end_idx: int) -> None:
224224 )
225225 self .per_fqn_lookups = new_per_fqn_lookups
226226
227- def get_unique (self , from_idx : int = 0 ) -> Dict [str , DeltaRows ]:
227+ def get_unique (self , from_idx : int = 0 ) -> Dict [str , UniqueRows ]:
228228 r"""
229229 Return all unique/delta ids per table from the Delta Store.
230230 """
231231
232- delta_per_table_fqn : Dict [str , DeltaRows ] = {}
232+ delta_per_table_fqn : Dict [str , UniqueRows ] = {}
233233 for table_fqn , lookups in self .per_fqn_lookups .items ():
234234 compact_ids = [
235235 lookup .ids for lookup in lookups if lookup .batch_idx >= from_idx
@@ -240,11 +240,11 @@ def get_unique(self, from_idx: int = 0) -> Dict[str, DeltaRows]:
240240 for lookup in lookups
241241 if lookup .batch_idx >= from_idx
242242 ]
243- if self .embdUpdateMode != EmbdUpdateMode .NONE
243+ if self .updateMode != UpdateMode .NONE
244244 else None
245245 )
246246
247247 delta_per_table_fqn [table_fqn ] = _compute_unique_rows (
248- ids = compact_ids , states = compact_states , mode = self .embdUpdateMode
248+ ids = compact_ids , states = compact_states , mode = self .updateMode
249249 )
250250 return delta_per_table_fqn
0 commit comments