|
@@ -8,19 +8,23 @@
 # pyre-strict
 
 import unittest
-from typing import List
+from typing import List, Optional
 from unittest.mock import MagicMock
 
 import parameterized
 
 import torch
+from torch import nn
 from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
 from torchrec.distributed.train_pipeline.tracing import (
+    _get_leaf_module_names,
     ArgInfo,
     ArgInfoStepFactory,
     CallArgs,
     NodeArgsHelper,
+    Tracer,
 )
+from torchrec.distributed.types import NullShardedModuleContext, ShardedModule
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 
@@ -110,3 +114,123 @@ def test_get_node_args_helper_call_module_kjt(self) -> None: |
 
         # Weights is call_module node, so we should only find 2 args unmodified
         self.assertEqual(num_found, len(kjt_args) - 1)
+
+
+class DummyShardedModule(
+    ShardedModule[torch.Tensor, torch.Tensor, torch.Tensor, NullShardedModuleContext]
+):
+    def __init__(self, alpha: float = 1) -> None:
+        super().__init__()
+        self.alpha = alpha
+
+    # pyre-ignore
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.alpha * x
+
+    # pyre-ignore
+    def compute(self) -> torch.Tensor:
+        return torch.empty(0)
+
+    def create_context(self) -> NullShardedModuleContext:
+        return NullShardedModuleContext()
+
+    # pyre-ignore
+    def input_dist(self, ctx: NullShardedModuleContext):
+        pass
+
+    # pyre-ignore
+    def output_dist(self):
+        pass
+
+    # pyre-ignore
+    def unsharded_module_type(self):
+        pass
+
+
+class DummyUmbrellaModule(nn.Module):
+    def __init__(self, m1: nn.Module, m2: nn.Module) -> None:
+        super().__init__()
+        self.m1 = m1
+        self.m2 = m2
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.m1(x) + self.m2(x)
+
+
+class DummyNestedModule(nn.Module):
+    def __init__(self, layer: int = 0) -> None:
+        super().__init__()
+        self.layer = layer
+        self.inner: Optional[nn.Module] = (
+            DummyNestedModule(layer - 1) if layer > 0 else None
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        inner = 0 if self.inner is None else self.inner(x)
+        return inner + 10**self.layer
+
+
+class TestFxTracer(unittest.TestCase):
+    @classmethod
+    def _generate_sharded_model(cls) -> nn.Module:
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.nested = DummyNestedModule(3)
+                self.umbrella1 = DummyUmbrellaModule(
+                    DummyNestedModule(2), DummyShardedModule()
+                )
+                self.umbrella2 = DummyUmbrellaModule(
+                    DummyNestedModule(3), DummyShardedModule()
+                )
+                self.umbrella3 = DummyUmbrellaModule(
+                    DummyNestedModule(4), DummyNestedModule(5)
+                )
+                self.umbrella4 = DummyUmbrellaModule(
+                    DummyNestedModule(6), DummyNestedModule(7)
+                )
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return (
+                    # umbrella2 and umbrella4 are not directly
+                    # called in this forward function
+                    self.nested(x)
+                    + self.umbrella1(x)
+                    + self.umbrella2.m1(x)
+                    + self.umbrella2.m2(x)
+                    + self.umbrella3(x)
+                    + self.umbrella4.m1(x)
+                    + self.umbrella4.m2(x)
+                )
+
+        return MyModel()
+
+    def test_get_leaf_module_names(self) -> None:
+        model = self._generate_sharded_model()
+        leaf_modules = _get_leaf_module_names(model)
+        self.assertSetEqual(
+            set(leaf_modules),  # umbrella1.m2 and umbrella2.m2 are `ShardedModule`s
+            {"nested", "umbrella1.m1", "umbrella2.m1", "umbrella3", "umbrella4"},
+        )
+
+    def test_top_level_tracer(self) -> None:
+        model = self._generate_sharded_model()
+        concrete_args = {}
+        tracer = Tracer(
+            leaf_modules=_get_leaf_module_names(model), extend_leaf_fqn=True
+        )
+        graph = tracer.trace(model, concrete_args=concrete_args)
+        targets = {node.target for node in graph.nodes if node.op == "call_module"}
+        self.assertSetEqual(
+            targets,
+            {
+                "nested",
+                "umbrella1.m1",
+                "umbrella1.m2",
+                "umbrella2.m1",
+                "umbrella2.m2",
+                "umbrella3",
+                "umbrella4.m1",  # umbrella4 is not called in model.forward
+                "umbrella4.m2",  # so umbrella4 is not a leaf module
+            },
+        )