|
7 | 7 |
|
8 | 8 |
|
9 | 9 | # pyre-strict |
| 10 | +from abc import ABC, abstractmethod |
10 | 11 | from bisect import bisect_left |
11 | 12 | from typing import Dict, List, Optional |
12 | 13 |
|
@@ -67,34 +68,106 @@ def _compute_unique_rows( |
67 | 68 | return DeltaRows(ids=unique_ids, states=unique_states) |
68 | 69 |
|
69 | 70 |
|
70 | | -class DeltaStore: |
| 71 | +class DeltaStore(ABC): |
71 | 72 | """ |
72 | | - DeltaStore is a helper class that stores and manages local delta (row) updates for embeddings/states across |
73 | | - various batches during training, designed to be used with TorchRecs ModelDeltaTracker. |
74 | | - It maintains a CUDA in-memory representation of requested ids and embeddings/states, |
| 73 | + DeltaStore is an abstract base class that defines the interface for storing and managing |
| 74 | + local delta (row) updates for embeddings/states across various batches during training. |
| 75 | +
|
| 76 | + Implementations should maintain a representation of requested ids and embeddings/states, |
75 | 77 | providing a way to compact and get delta updates for each embedding table. |
76 | 78 |
|
77 | 79 | The class supports different embedding update modes (NONE, FIRST, LAST) to determine |
78 | 80 | how to handle duplicate ids when compacting or retrieving embeddings. |
| 81 | + """ |
| 82 | + |
| 83 | + @abstractmethod |
| 84 | + def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: |
| 85 | + pass |
| 86 | + |
| 87 | + @abstractmethod |
| 88 | + def append( |
| 89 | + self, |
| 90 | + batch_idx: int, |
| 91 | + fqn: str, |
| 92 | + ids: torch.Tensor, |
| 93 | + states: Optional[torch.Tensor], |
| 94 | + ) -> None: |
| 95 | + """ |
| 96 | + Append a batch of ids and states to the store for a specific table. |
| 97 | +
|
| 98 | + Args: |
| 99 | + batch_idx: The batch index |
| 100 | + table_fqn: The fully qualified name of the table |
| 101 | + ids: The tensor of ids to append |
| 102 | + states: Optional tensor of states to append |
| 103 | + """ |
| 104 | + pass |
| 105 | + |
| 106 | + @abstractmethod |
| 107 | + def delete(self, up_to_idx: Optional[int] = None) -> None: |
| 108 | + """ |
| 109 | + Delete all idx from the store up to `up_to_idx` |
| 110 | +
|
| 111 | + Args: |
| 112 | + up_to_idx: Optional index up to which to delete lookups |
| 113 | + """ |
| 114 | + pass |
79 | 115 |
|
| 116 | + @abstractmethod |
| 117 | + def compact(self, start_idx: int, end_idx: int) -> None: |
| 118 | + """ |
| 119 | + Compact (ids, embeddings) in batch index range from start_idx to end_idx. |
| 120 | +
|
| 121 | + Args: |
| 122 | + start_idx: The starting batch index |
| 123 | + end_idx: The ending batch index |
| 124 | + """ |
| 125 | + pass |
| 126 | + |
| 127 | + @abstractmethod |
| 128 | + def get_delta(self, from_idx: int = 0) -> Dict[str, DeltaRows]: |
| 129 | + """ |
| 130 | + Return all unique/delta ids per table from the Delta Store. |
| 131 | +
|
| 132 | + Args: |
| 133 | + from_idx: The batch index from which to get deltas |
| 134 | +
|
| 135 | + Returns: |
| 136 | + A dictionary mapping table FQNs to their delta rows |
| 137 | + """ |
| 138 | + pass |
| 139 | + |
| 140 | + |
| 141 | +class DeltaStoreTrec(DeltaStore): |
| 142 | + """ |
| 143 | + DeltaStoreTrec is a concrete implementation of DeltaStore that stores and manages |
| 144 | + local delta (row) updates for embeddings/states across various batches during training, |
| 145 | + designed to be used with TorchRecs ModelDeltaTracker. |
| 146 | +
|
| 147 | + It maintains a CUDA in-memory representation of requested ids and embeddings/states, |
| 148 | + providing a way to compact and get delta updates for each embedding table. |
| 149 | +
|
| 150 | + The class supports different embedding update modes (NONE, FIRST, LAST) to determine |
| 151 | + how to handle duplicate ids when compacting or retrieving embeddings. |
80 | 152 | """ |
81 | 153 |
|
82 | 154 | def __init__(self, embdUpdateMode: EmbdUpdateMode = EmbdUpdateMode.NONE) -> None: |
| 155 | + super().__init__(embdUpdateMode) |
83 | 156 | self.embdUpdateMode = embdUpdateMode |
84 | 157 | self.per_fqn_lookups: Dict[str, List[IndexedLookup]] = {} |
85 | 158 |
|
86 | 159 | def append( |
87 | 160 | self, |
88 | 161 | batch_idx: int, |
89 | | - table_fqn: str, |
| 162 | + fqn: str, |
90 | 163 | ids: torch.Tensor, |
91 | 164 | states: Optional[torch.Tensor], |
92 | 165 | ) -> None: |
93 | | - table_fqn_lookup = self.per_fqn_lookups.get(table_fqn, []) |
| 166 | + table_fqn_lookup = self.per_fqn_lookups.get(fqn, []) |
94 | 167 | table_fqn_lookup.append( |
95 | 168 | IndexedLookup(batch_idx=batch_idx, ids=ids, states=states) |
96 | 169 | ) |
97 | | - self.per_fqn_lookups[table_fqn] = table_fqn_lookup |
| 170 | + self.per_fqn_lookups[fqn] = table_fqn_lookup |
98 | 171 |
|
99 | 172 | def delete(self, up_to_idx: Optional[int] = None) -> None: |
100 | 173 | """ |
|
0 commit comments