Skip to content

Commit 0e20667

Browse files
aliafzalmeta-codesync[bot]
authored andcommitted
Allow torchrec delta tracker to be initialized post init (#3472)
Summary: Pull Request resolved: #3472 internal General Context: We are in the process of transition to a unified DeltaTracker and this is 6/n diffs representing changes towards the transition. Context: MRS DeltaTracker is initialized right before training, to allow for OSS DeltaTracker to have similar behavior adding a post DMP init function for initializing ModelDeltaTracker if not initialized. Differential Revision: D80615308 fbshipit-source-id: 10c73274793580e2c1d0e7e6efd9377cab66fd99
1 parent c2414ac commit 0e20667

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

torchrec/distributed/model_parallel.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
append_prefix,
5151
copy_to_device,
5252
filter_state_dict,
53+
none_throws,
5354
sharded_model_copy,
5455
)
5556
from torchrec.optim.fused import FusedOptimizerModule
@@ -456,6 +457,19 @@ def init_parameters(module: nn.Module) -> None:
456457

457458
module.apply(init_parameters)
458459

460+
def init_torchrec_delta_tracker(
461+
self, model_tracker_config: ModelTrackerConfig
462+
) -> ModelDeltaTrackerTrec:
463+
"""
464+
Initializes the model delta tracker if it doesn't exists.
465+
"""
466+
if self.model_delta_tracker is None:
467+
self.model_delta_tracker = self._init_delta_tracker(
468+
model_tracker_config, self._dmp_wrapped_module
469+
)
470+
471+
return none_throws(self.model_delta_tracker)
472+
459473
def get_model_tracker(self) -> ModelDeltaTrackerTrec:
460474
"""
461475
Returns the model tracker if it exists.

0 commit comments

Comments
 (0)