
Commit 43891bc

TroyGarden authored and facebook-github-bot committed
fix the fx tracer leaf module logic (#3487)
Summary:

# context
* In response to a TorchRec User Group [post](https://fb.workplace.com/groups/970281557043698/permalink/1984937168911460/).
* In torchrec.train_pipeline, the modified fx tracer derives its leaf modules from the `named_modules()` API of `torch.nn.Module`. There are corner cases where a returned "leaf_module" is never actually called in the forward pass: for example, an umbrella module can be reported as the top-level leaf module even though it never appears in the fx-traced graph, because only its submodules are called directly.

# example
* traced graph

```
opcode         name          target                   args                          kwargs
-------------  ------------  -----------------------  ----------------------------  --------
placeholder    x             x                        ()                            {}
call_module    nested        nested                   (x,)                          {}
call_module    umbrella1_m1  umbrella1.m1             (x,)                          {}
call_module    umbrella1_m2  umbrella1.m2             (x,)                          {}
call_function  add           <built-in function add>  (umbrella1_m1, umbrella1_m2)  {}
call_function  add_1         <built-in function add>  (nested, add)                 {}
call_module    umbrella2_m1  umbrella2.m1             (x,)                          {}
call_function  add_2         <built-in function add>  (add_1, umbrella2_m1)         {}
call_module    umbrella2_m2  umbrella2.m2             (x,)                          {}
call_function  add_3         <built-in function add>  (add_2, umbrella2_m2)         {}
call_module    umbrella3     umbrella3                (x,)                          {}
call_function  add_4         <built-in function add>  (add_3, umbrella3)            {}
call_module    umbrella4_m1  umbrella4.m1             (x,)                          {}
call_function  add_5         <built-in function add>  (add_4, umbrella4_m1)         {}
call_module    umbrella4_m2  umbrella4.m2             (x,)                          {}
call_function  add_6         <built-in function add>  (add_5, umbrella4_m2)         {}
output         output        output                   (add_6,)                      {}
```

Differential Revision: D80858379
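To make the corner case concrete, here is a minimal, self-contained sketch using plain `torch.fx` (the `Umbrella` and `Model` names are made up for illustration and are not part of this commit): `named_modules()` still reports the umbrella module, but because the forward pass only calls its children directly, no `call_module` node for the umbrella itself ever appears in the traced graph.

```
import torch
from torch import nn


class Umbrella(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m1 = nn.Linear(4, 4)
        self.m2 = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.m1(x) + self.m2(x)


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.umbrella = Umbrella()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the umbrella module itself is never called here, only its children
        return self.umbrella.m1(x) + self.umbrella.m2(x)


model = Model()
# named_modules() reports the umbrella module as a submodule ...
print([name for name, _ in model.named_modules() if name])
# ['umbrella', 'umbrella.m1', 'umbrella.m2']

# ... but only its children show up as call_module nodes in the traced graph
gm = torch.fx.symbolic_trace(model)
print({node.target for node in gm.graph.nodes if node.op == "call_module"})
# {'umbrella.m1', 'umbrella.m2'}
```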
1 parent 3a6cf2e commit 43891bc

2 files changed: +176 −2 lines changed

torchrec/distributed/train_pipeline/tests/test_tracing.py

Lines changed: 125 additions & 1 deletion
@@ -8,19 +8,23 @@
 # pyre-strict
 
 import unittest
-from typing import List
+from typing import List, Optional
 from unittest.mock import MagicMock
 
 import parameterized
 
 import torch
+from torch import nn
 from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
 from torchrec.distributed.train_pipeline.tracing import (
+    _get_leaf_module_names,
     ArgInfo,
     ArgInfoStepFactory,
     CallArgs,
     NodeArgsHelper,
+    Tracer,
 )
+from torchrec.distributed.types import NullShardedModuleContext, ShardedModule
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
@@ -110,3 +114,123 @@ def test_get_node_args_helper_call_module_kjt(self) -> None:
 
         # Weights is call_module node, so we should only find 2 args unmodified
         self.assertEqual(num_found, len(kjt_args) - 1)
+
+
+class DummyShardedModule(
+    ShardedModule[torch.Tensor, torch.Tensor, torch.Tensor, NullShardedModuleContext]
+):
+    def __init__(self, alpha: float = 1) -> None:
+        super().__init__()
+        self.alpha = alpha
+
+    # pyre-ignore
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.alpha * x
+
+    # pyre-ignore
+    def compute(self) -> torch.Tensor:
+        return torch.empty(0)
+
+    def create_context(self) -> NullShardedModuleContext:
+        return NullShardedModuleContext()
+
+    # pyre-ignore
+    def input_dist(self, ctx: NullShardedModuleContext):
+        pass
+
+    # pyre-ignore
+    def output_dist(self):
+        pass
+
+    # pyre-ignore
+    def unsharded_module_type(self):
+        pass
+
+
+class DummyUmbrellaModule(nn.Module):
+    def __init__(self, m1: nn.Module, m2: nn.Module) -> None:
+        super().__init__()
+        self.m1 = m1
+        self.m2 = m2
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.m1(x) + self.m2(x)
+
+
+class DummyNestedModule(nn.Module):
+    def __init__(self, layer: int = 0) -> None:
+        super().__init__()
+        self.layer = layer
+        self.inner: Optional[nn.Module] = (
+            DummyNestedModule(layer - 1) if layer > 0 else None
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        inner = 0 if self.inner is None else self.inner(x)
+        return inner + 10**self.layer
+
+
+class TestFxTracer(unittest.TestCase):
+    @classmethod
+    def _generate_sharded_model(cls) -> nn.Module:
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = DummyNestedModule(3)
+                self.umbrella1 = DummyUmbrellaModule(
+                    DummyNestedModule(2), DummyShardedModule()
+                )
+                self.umbrella2 = DummyUmbrellaModule(
+                    DummyNestedModule(3), DummyShardedModule()
+                )
+                self.umbrella3 = DummyUmbrellaModule(
+                    DummyNestedModule(4), DummyNestedModule(5)
+                )
+                self.umbrella4 = DummyUmbrellaModule(
+                    DummyNestedModule(6), DummyNestedModule(7)
+                )
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return (
+                    # umbrella2 and umbrella4 are not directly
+                    # called in this forward function
+                    self.nested(x)
+                    + self.umbrella1(x)
+                    + self.umbrella2.m1(x)
+                    + self.umbrella2.m2(x)
+                    + self.umbrella3(x)
+                    + self.umbrella4.m1(x)
+                    + self.umbrella4.m2(x)
+                )
+
+        return MyModel()
+
+    def test_get_leaf_module_names(self) -> None:
+        model = self._generate_sharded_model()
+        leaf_modules = _get_leaf_module_names(model)
+        self.assertSetEqual(
+            set(leaf_modules),  # umbrella1.m2 and umbrella2.m2 are `ShardedModule`s
+            {"nested", "umbrella1.m1", "umbrella2.m1", "umbrella3", "umbrella4"},
+        )
+
+    def test_top_level_tracer(self) -> None:
+        model = self._generate_sharded_model()
+        concrete_args = {}
+        tracer = Tracer(
+            leaf_modules=_get_leaf_module_names(model), extend_leaf_fqn=True
+        )
+        graph = tracer.trace(model, concrete_args=concrete_args)
+        targets = {node.target for node in graph.nodes if node.op == "call_module"}
+        self.assertSetEqual(
+            targets,
+            {
+                "nested",
+                "umbrella1.m1",
+                "umbrella1.m2",
+                "umbrella2.m1",
+                "umbrella2.m2",
+                "umbrella3",
+                "umbrella4.m1",  # umbrella4 is not called in model.forward
+                "umbrella4.m2",  # so umbrella4 is not a leaf module
+            },
+        )

torchrec/distributed/train_pipeline/tracing.py

Lines changed: 51 additions & 1 deletion
@@ -514,6 +514,29 @@ def _get_leaf_module_names(model: torch.nn.Module) -> List[str]:
     This is a shallow FX trace that only goes the minimum depth required to pipeline.
     Any sub-module who does not contain a ShardedModule would be considered as a leaf
     module unless explicitly tagged as `_is_pytorch_fx_traceable = True`.
+
+    disclaimer:
+    the algorithm is based on "named_modules()" API from torch.nn.Module, there
+    could be corner cases that a returned "leaf_module" is not actually called in the
+    forward pass. for example the umbrella_module would be considered as the top-level
+    leaf module but it actually won't appear in a fx-traced graph.
+
+    ```
+    # the main_model's hierarchy looks like below:
+    main_model
+        - sharded_module
+        - umbrella_module
+            - actual_leaf_module_1
+            - actual_leaf_module_2
+
+    # and the main_model's forward is something like:
+    def forward(self, x1, x2, x3):
+        emb1 = self.sharded_module(x1)
+        emb2 = self.umbrella_module.actual_leaf_module_1(x2)
+        emb3 = self.umbrella_module.actual_leaf_module_2(x3)
+        return emb1 + emb2 + emb3
+    ```
+
     """
 
     def _get_leaf_module_names_helper(
@@ -573,9 +596,22 @@ class Tracer(torch.fx.Tracer):
     # remove this line.
     proxy_buffer_attributes = False
 
-    def __init__(self, leaf_modules: Optional[List[str]] = None) -> None:
+    def __init__(
+        self, leaf_modules: Optional[List[str]] = None, extend_leaf_fqn: bool = False
+    ) -> None:
+        """
+        Initializes the Tracer for FX tracing with custom leaf module handling.
+
+        Args:
+            leaf_modules: Optional list of fully qualified names (FQNs) of modules to treat
+                as leaf modules during tracing. If None, defaults to an empty list.
+            extend_leaf_fqn: If True, treats any module whose FQN starts with a leaf module
+                FQN as a leaf module (includes submodules). If False, only exact matches
+                are considered leaf modules. Defaults to False.
+        """
         super().__init__()
         self._leaf_modules: List[str] = leaf_modules if leaf_modules is not None else []
+        self._extend_leaf_fqn = extend_leaf_fqn
 
     def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
         if (
@@ -585,4 +621,18 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool
             or isinstance(m, FSDP2)
         ):
             return True
+        if self._extend_leaf_fqn:
+            if self.is_extended_leaf_modules(m, module_qualified_name):
+                return True
         return super().is_leaf_module(m, module_qualified_name)
+
+    def is_extended_leaf_modules(
+        self, m: torch.nn.Module, module_qualified_name: str
+    ) -> bool:
+        for leaf_module in self._leaf_modules:
+            if module_qualified_name.startswith(leaf_module):
+                # in a corner case that the fqn == 'main_model.leaf_module.submod'
+                # we should consider this fqn also a leaf_module
+                if module_qualified_name[len(leaf_module)] == ".":
+                    return True
+        return False
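For readers skimming the diff, the following is a small standalone sketch of the prefix-matching rule that `extend_leaf_fqn=True` enables (an illustrative re-implementation, not the TorchRec code above; the helper name `is_extended_leaf` is made up): an FQN counts as a leaf when it sits strictly below a registered leaf-module FQN, and the trailing `"."` check keeps a name like `umbrella10.m1` from matching the leaf `umbrella1`.

```
from typing import List


def is_extended_leaf(fqn: str, leaf_modules: List[str]) -> bool:
    # a module counts as an extended leaf when its FQN is nested under a
    # registered leaf FQN; checking that the next character is "." avoids
    # false positives on bare prefix matches (e.g. "umbrella10" vs "umbrella1")
    for leaf_module in leaf_modules:
        if fqn.startswith(leaf_module) and len(fqn) > len(leaf_module):
            if fqn[len(leaf_module)] == ".":
                return True
    return False


assert is_extended_leaf("umbrella1.m1", ["umbrella1"])
assert is_extended_leaf("umbrella1.m1.inner", ["umbrella1"])
assert not is_extended_leaf("umbrella10.m1", ["umbrella1"])
assert not is_extended_leaf("nested", ["umbrella1"])
```

In the committed `Tracer`, the analogous check (`is_extended_leaf_modules`) runs only when `extend_leaf_fqn=True`, before falling back to `super().is_leaf_module`.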

0 commit comments
