Commit 397a6f6
Adding support for tracking optimizer states in Model Delta Tracker. (#3143)
Summary:
Pull Request resolved: #3143
X-link: #3143
### Overview
This diff adds support for tracking optimizer states in the Model Delta Tracker system. It introduces a new tracking mode called `MOMENTUM_LAST` that enables tracking of momentum values from optimizers to support approximate top-k delta-row selection.
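The summary does not spell out how momentum values feed the top-k selection. As a minimal sketch, assuming rows are ranked by the absolute value of their most recently recorded momentum, the selection could look like the following. The function name `top_k_rows_by_momentum` and the dict-based input are hypothetical and are not part of this diff.

```python
import heapq
from typing import Dict, List


def top_k_rows_by_momentum(latest_momentum: Dict[int, float], k: int) -> List[int]:
    """Return the k row ids with the largest absolute momentum, largest first.

    Hypothetical helper: the ranking criterion is an assumption, not taken
    from the diff.
    """
    return heapq.nlargest(
        k, latest_momentum, key=lambda row: abs(latest_momentum[row])
    )


# Row 1 has the largest magnitude, row 0 the second largest.
rows = top_k_rows_by_momentum({0: 0.5, 1: -0.9, 2: 0.1, 3: 0.0}, k=2)
```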
### Key Changes
#### 1. Optimizer State Tracking Support
* To support optimizer state tracking, added an `optim_state_tracker_fn` attribute to the `GroupedEmbeddingsLookup` and `GroupedPooledEmbeddingsLookup` classes, which traverse the BatchedFused modules.
* Implemented a `register_optim_state_tracker_fn()` method in both classes to register the tracking callable.
* Tracking calls are invoked after each lookup operation.
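
The register-then-invoke pattern described above can be sketched in plain Python. This is a simplified illustration, not the torchrec code: `ToyEmbeddingModule`, `ToyGroupedLookup`, and the callback signature are hypothetical stand-ins. Only the attribute name `optim_state_tracker_fn` and the method name `register_optim_state_tracker_fn()` come from the summary.

```python
from typing import Callable, Dict, List, Optional


class ToyEmbeddingModule:
    """Hypothetical stand-in for a module that owns optimizer state."""

    def __init__(self, name: str, momentum: List[float]) -> None:
        self.name = name
        self.momentum = momentum  # one momentum value per embedding row


# Assumed callback shape: (looked-up row ids, module that served them).
OptimStateTrackerFn = Callable[[List[int], ToyEmbeddingModule], None]


class ToyGroupedLookup:
    """Holds an optional tracker and calls it after each lookup."""

    def __init__(self, modules: List[ToyEmbeddingModule]) -> None:
        self._modules = modules
        self.optim_state_tracker_fn: Optional[OptimStateTrackerFn] = None

    def register_optim_state_tracker_fn(self, fn: OptimStateTrackerFn) -> None:
        self.optim_state_tracker_fn = fn

    def forward(self, ids_per_module: List[List[int]]) -> List[List[int]]:
        outputs = []
        for module, ids in zip(self._modules, ids_per_module):
            outputs.append(list(ids))  # placeholder for the real lookup
            # The tracking call runs only if a tracker was registered.
            if self.optim_state_tracker_fn is not None:
                self.optim_state_tracker_fn(ids, module)
        return outputs


seen: Dict[str, List[int]] = {}
lookup = ToyGroupedLookup([ToyEmbeddingModule("t0", [0.0, 0.1, 0.2])])
lookup.forward([[1]])  # no tracker registered yet, so nothing is recorded
lookup.register_optim_state_tracker_fn(
    lambda ids, module: seen.setdefault(module.name, []).extend(ids)
)
lookup.forward([[2, 0]])
```

Keeping the tracker optional means lookups behave as before when no tracker is registered.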
#### 2. Model Delta Tracker Changes
* Added a `record_momentum()` method to track momentum values from optimizer states, and added support for it in the `record_lookup` function.
* Added validation and optimizer tracker function logic to support the new `MOMENTUM_LAST` mode.
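
To illustrate what recording momentum with "last" semantics might look like, here is a minimal sketch in which each recorded row keeps only its most recent momentum value. `ToyMomentumStore`, its `snapshot()` method, and the argument shapes are hypothetical. The actual `record_momentum()` signature is not shown in this summary.

```python
from typing import Dict, List


class ToyMomentumStore:
    """Keeps only the latest momentum seen for each row id."""

    def __init__(self) -> None:
        self._latest: Dict[int, float] = {}

    def record_momentum(self, ids: List[int], momentum: List[float]) -> None:
        # `momentum` is indexed by row id (an assumption of this sketch).
        # A later call overwrites any earlier value for the same row.
        for row in ids:
            self._latest[row] = momentum[row]

    def snapshot(self) -> Dict[int, float]:
        return dict(self._latest)


store = ToyMomentumStore()
store.record_momentum([0, 2], [0.5, 0.0, 0.1])
store.record_momentum([2], [0.5, 0.0, 0.9])  # row 2 is overwritten
```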
#### 3. New Tracking Mode
* Added `TrackingMode.MOMENTUM_LAST` to `types.py` (`torchrec/distributed/model_tracker/types.py`)
* Maps to `EmbdUpdateMode.LAST` to capture the most recent momentum values
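
The mode mapping can be pictured as a small lookup table. The enums below are simplified stand-ins that contain only the two members named in this summary. The real enums in `types.py` may have other members, and the `update_mode_for` helper and its error handling are hypothetical.

```python
from enum import Enum, auto
from typing import Dict


class EmbdUpdateMode(Enum):
    LAST = auto()  # keep the most recent value per row


class TrackingMode(Enum):
    MOMENTUM_LAST = auto()  # track momentum, keeping the latest value


# Assumed shape of the mapping from tracking mode to update mode.
_UPDATE_MODE_BY_TRACKING_MODE: Dict[TrackingMode, EmbdUpdateMode] = {
    TrackingMode.MOMENTUM_LAST: EmbdUpdateMode.LAST,
}


def update_mode_for(mode: TrackingMode) -> EmbdUpdateMode:
    """Resolve the update mode, rejecting modes that have no mapping."""
    try:
        return _UPDATE_MODE_BY_TRACKING_MODE[mode]
    except KeyError:
        raise ValueError(f"Unsupported tracking mode: {mode}") from None


resolved = update_mode_for(TrackingMode.MOMENTUM_LAST)
```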
Differential Revision: D76868111
fbshipit-source-id: bde3d4be8d3df7fe5b2f284a262c50a5313c1dc0
1 parent 80dbb88 commit 397a6f6
File tree
7 files changed
+279 -18 lines changed
- torchrec/distributed
- model_tracker
- tests
Lines changed: 75 additions & 7 deletions