From 43891bc4b6d1057a12f69d21322a960b88c33387 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Mon, 27 Oct 2025 12:48:01 -0700
Subject: [PATCH] fix the fx tracer leaf module logic (#3487)

Summary:
# context
* in response to a TorchRec User Group [post](https://fb.workplace.com/groups/970281557043698/permalink/1984937168911460/)
* in torchrec.train_pipeline, the modified fx tracer determines leaf modules with the `named_modules()` API from `torch.nn.Module`. There are corner cases where a returned "leaf_module" is never actually called in the forward pass: for example, an umbrella_module can be returned as a top-level leaf module even though it never appears in the fx-traced graph, because only its submodules are called directly in `forward`.
* this diff adds an `extend_leaf_fqn` option to `Tracer` so that submodules of a registered leaf-module FQN (e.g. `umbrella2.m1` under the leaf `umbrella2`) are also treated as leaf modules during tracing.

# example
* traced graph of the test model (`MyModel` in the new unit tests):
```
opcode         name          target                   args                          kwargs
-------------  ------------  -----------------------  ----------------------------  --------
placeholder    x             x                        ()                            {}
call_module    nested        nested                   (x,)                          {}
call_module    umbrella1_m1  umbrella1.m1             (x,)                          {}
call_module    umbrella1_m2  umbrella1.m2             (x,)                          {}
call_function  add           <built-in function add>  (umbrella1_m1, umbrella1_m2)  {}
call_function  add_1         <built-in function add>  (nested, add)                 {}
call_module    umbrella2_m1  umbrella2.m1             (x,)                          {}
call_function  add_2         <built-in function add>  (add_1, umbrella2_m1)         {}
call_module    umbrella2_m2  umbrella2.m2             (x,)                          {}
call_function  add_3         <built-in function add>  (add_2, umbrella2_m2)         {}
call_module    umbrella3     umbrella3                (x,)                          {}
call_function  add_4         <built-in function add>  (add_3, umbrella3)            {}
call_module    umbrella4_m1  umbrella4.m1             (x,)                          {}
call_function  add_5         <built-in function add>  (add_4, umbrella4_m1)         {}
call_module    umbrella4_m2  umbrella4.m2             (x,)                          {}
call_function  add_6         <built-in function add>  (add_5, umbrella4_m2)         {}
output         output        output                   (add_6,)                      {}
```

Differential Revision: D80858379
---
 .../train_pipeline/tests/test_tracing.py  | 126 +++++++++++++++++-
 .../distributed/train_pipeline/tracing.py |  52 +++++++-
 2 files changed, 176 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/tests/test_tracing.py b/torchrec/distributed/train_pipeline/tests/test_tracing.py
index 13103edf0..e2c19b6cf 100644
--- a/torchrec/distributed/train_pipeline/tests/test_tracing.py
+++ b/torchrec/distributed/train_pipeline/tests/test_tracing.py
@@ -8,19 +8,23 @@
 # pyre-strict
 
 import unittest
-from typing import List
+from typing import List, Optional
 from unittest.mock import MagicMock
 
 import parameterized
 import torch
+from torch import nn
 
 from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
 from torchrec.distributed.train_pipeline.tracing import (
+    _get_leaf_module_names,
     ArgInfo,
     ArgInfoStepFactory,
     CallArgs,
     NodeArgsHelper,
+    Tracer,
 )
+from torchrec.distributed.types import NullShardedModuleContext, ShardedModule
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
@@ -110,3 +114,123 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
 
         # Weights is call_module node, so we should only find 2 args unmodified
         self.assertEqual(num_found, len(kjt_args) - 1)
+
+
+class DummyShardedModule(
+    ShardedModule[torch.Tensor, torch.Tensor, torch.Tensor, NullShardedModuleContext]
+):
+    def __init__(self, alpha: float = 1) -> None:
+        super().__init__()
+        self.alpha = alpha
+
+    # pyre-ignore
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.alpha * x
+
+    # pyre-ignore
+    def compute(self) -> torch.Tensor:
+        return torch.empty(0)
+
+    def create_context(self) -> NullShardedModuleContext:
+        return NullShardedModuleContext()
+
+    # pyre-ignore
+    def input_dist(self, ctx: NullShardedModuleContext):
+        pass
+
+    # pyre-ignore
+    def output_dist(self):
+        pass
+
+    # pyre-ignore
+    def unsharded_module_type(self):
+        pass
+
+
+class DummyUmbrellaModule(nn.Module):
+    def __init__(self, m1: nn.Module, m2: nn.Module) -> None:
+        super().__init__()
+        self.m1 = m1
+        self.m2 = m2
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.m1(x) + self.m2(x)
+
+
+class DummyNestedModule(nn.Module):
+    def __init__(self, layer: int = 0) -> None:
+        super().__init__()
+        self.layer = layer
+        self.inner: Optional[nn.Module] = (
+            DummyNestedModule(layer - 1) if layer > 0 else None
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        inner = 0 if self.inner is None else self.inner(x)
+        return inner + 10**self.layer
+
+
+class TestFxTracer(unittest.TestCase):
+    @classmethod
+    def _generate_sharded_model(cls) -> nn.Module:
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = DummyNestedModule(3)
+                self.umbrella1 = DummyUmbrellaModule(
+                    DummyNestedModule(2), DummyShardedModule()
+                )
+                self.umbrella2 = DummyUmbrellaModule(
+                    DummyNestedModule(3), DummyShardedModule()
+                )
+                self.umbrella3 = DummyUmbrellaModule(
+                    DummyNestedModule(4), DummyNestedModule(5)
+                )
+                self.umbrella4 = DummyUmbrellaModule(
+                    DummyNestedModule(6), DummyNestedModule(7)
+                )
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return (
+                    # umbrella2 and umbrella4 are not directly
+                    # called in this forward function
+                    self.nested(x)
+                    + self.umbrella1(x)
+                    + self.umbrella2.m1(x)
+                    + self.umbrella2.m2(x)
+                    + self.umbrella3(x)
+                    + self.umbrella4.m1(x)
+                    + self.umbrella4.m2(x)
+                )
+
+        return MyModel()
+
+    def test_get_leaf_module_names(self) -> None:
+        model = self._generate_sharded_model()
+        leaf_modules = _get_leaf_module_names(model)
+        self.assertSetEqual(
+            set(leaf_modules),  # umbrella1.m2 and umbrella2.m2 are `ShardedModule`s
+            {"nested", "umbrella1.m1", "umbrella2.m1", "umbrella3", "umbrella4"},
+        )
+
+    def test_top_level_tracer(self) -> None:
+        model = self._generate_sharded_model()
+        concrete_args = {}
+        tracer = Tracer(
+            leaf_modules=_get_leaf_module_names(model), extend_leaf_fqn=True
+        )
+        graph = tracer.trace(model, concrete_args=concrete_args)
+        targets = {node.target for node in graph.nodes if node.op == "call_module"}
+        self.assertSetEqual(
+            targets,
+            {
+                "nested",
+                "umbrella1.m1",
+                "umbrella1.m2",
+                "umbrella2.m1",
+                "umbrella2.m2",
+                "umbrella3",
+                "umbrella4.m1",  # umbrella4 is not called in model.forward
+                "umbrella4.m2",  # so umbrella4 is not a leaf module
+            },
+        )
diff --git a/torchrec/distributed/train_pipeline/tracing.py b/torchrec/distributed/train_pipeline/tracing.py
index efc0bf962..199d9f3f0 100644
--- a/torchrec/distributed/train_pipeline/tracing.py
+++ b/torchrec/distributed/train_pipeline/tracing.py
@@ -514,6 +514,29 @@ def _get_leaf_module_names(model: torch.nn.Module) -> List[str]:
     This is a shallow FX trace that only goes the minimum depth required to pipeline.
     Any sub-module who does not contain a ShardedModule would be considered as a leaf
     module unless explicitly tagged as `_is_pytorch_fx_traceable = True`.
+
+    disclaimer:
+        the algorithm is based on the "named_modules()" API from torch.nn.Module, so
+        there can be corner cases where a returned "leaf_module" is never actually
+        called in the forward pass. For example, the umbrella_module below would be
+        considered the top-level leaf module, yet it won't appear in an fx-traced graph.
+
+    ```
+    # the main_model's hierarchy looks like this:
+    main_model
+        - sharded_module
+        - umbrella_module
+            - actual_leaf_module_1
+            - actual_leaf_module_2
+
+    # and the main_model's forward is something like:
+    def forward(self, x1, x2, x3):
+        emb1 = self.sharded_module(x1)
+        emb2 = self.umbrella_module.actual_leaf_module_1(x2)
+        emb3 = self.umbrella_module.actual_leaf_module_2(x3)
+        return emb1 + emb2 + emb3
+    ```
+
     """
 
     def _get_leaf_module_names_helper(
@@ -573,9 +596,22 @@ class Tracer(torch.fx.Tracer):
     # remove this line.
     proxy_buffer_attributes = False
 
-    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
+    def __init__(
+        self, leaf_modules: Optional[List[str]] = None, extend_leaf_fqn: bool = False
+    ) -> None:
+        """
+        Initializes the Tracer for FX tracing with custom leaf module handling.
+
+        Args:
+            leaf_modules: Optional list of fully qualified names (FQNs) of modules to treat
+                as leaf modules during tracing. If None, defaults to an empty list.
+            extend_leaf_fqn: If True, also treats any module whose FQN extends a leaf module
+                FQN across a "." boundary (i.e. its submodules) as a leaf module. If False,
+                only exact matches are considered leaf modules. Defaults to False.
+        """
         super().__init__()
         self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
+        self._extend_leaf_fqn = extend_leaf_fqn
 
     def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
         if (
@@ -585,4 +621,18 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
             or isinstance(m, FSDP2)
         ):
             return True
+        if self._extend_leaf_fqn:
+            if self.is_extended_leaf_modules(m, module_qualified_name):
+                return True
         return super().is_leaf_module(m, module_qualified_name)
+
+    def is_extended_leaf_modules(
+        self, m: torch.nn.Module, module_qualified_name: str
+    ) -> bool:
+        for leaf_module in self._leaf_modules:
+            if module_qualified_name.startswith(leaf_module):
+                # corner case: when the fqn is e.g. 'main_model.leaf_module.submod',
+                # it should also be treated as a leaf module
+                if module_qualified_name[len(leaf_module)] == ".":
+                    return True
+        return False
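For reference, here is a minimal standalone sketch (not part of the patch) of the FQN boundary rule that `is_extended_leaf_modules` implements: a qualified name counts as an extended leaf only when it extends a registered leaf FQN across a `.` boundary, so `umbrella1.m1` matches the leaf `umbrella1` while a sibling name such as `umbrella10` does not. The helper name `_is_extended_leaf_fqn` below is hypothetical and exists purely for illustration.

```python
from typing import List


def _is_extended_leaf_fqn(qualified_name: str, leaf_modules: List[str]) -> bool:
    # Hypothetical helper mirroring the patch's matching rule: a module is an
    # extended leaf when its fully qualified name is a strict submodule of a
    # registered leaf FQN, i.e. the leaf FQN followed by a "." boundary.
    for leaf in leaf_modules:
        if qualified_name != leaf and qualified_name.startswith(leaf + "."):
            return True
    return False


if __name__ == "__main__":
    leaves = ["umbrella1", "nested"]
    assert _is_extended_leaf_fqn("umbrella1.m1", leaves)    # submodule of a leaf
    assert not _is_extended_leaf_fqn("umbrella10", leaves)  # shared prefix, no "." boundary
    assert not _is_extended_leaf_fqn("umbrella1", leaves)   # exact match, not an *extended* leaf
```

Exact matches are excluded in this sketch on the assumption that, as in the patch, exact leaf FQNs are already handled by the existing checks in `is_leaf_module` before the extended check runs.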