Skip to content

Commit e7ee95a

Browse files
authored
[Compiler Toolkit] Make compiler toolkit work with checkpoint (#2030)
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * #2029 * __->__ #2030 The current CompileModule will result in an "inner" prefix for everything. This PR fixes it by overloading the methods. This PR also merges #2028, since something went wrong with ghstack.
1 parent ce1c0fc commit e7ee95a

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,17 @@ def parallelize_deepseekv3(
8080
compiler_passes = get_compiler_passes_from_config(job_config)
8181

8282
# Create compilers with specified passes (defaults to no passes)
83-
fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
83+
fw_compiler, bw_compiler = make_compiler_with_passes(
84+
compiler_passes, dump_folder=job_config.job.dump_folder
85+
)
8486

8587
# Create custom joint_graph_builder with deepseekv3-specific compilers
8688
deepseekv3_joint_graph_builder = functools.partial(
8789
joint_graph_builder,
8890
fw_compiler=fw_compiler,
8991
bw_compiler=bw_compiler,
9092
joint_custom_pass=validate_flex_attention_annotation,
93+
dump_folder=job_config.job.dump_folder,
9194
)
9295

9396
# TODO: CompiledModule should take sample input as well, so that we can

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import contextlib
8-
from typing import Callable, List, Optional
8+
from pathlib import Path
9+
from typing import Any, Callable, List, Optional
910

1011
import torch
1112
from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -21,8 +22,18 @@
2122
from torchtitan.tools.logging import logger
2223

2324

25+
def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
26+
# TODO: make the dump rank configurable
27+
if not dump_folder or torch.distributed.get_rank() != 0:
28+
return
29+
30+
output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
31+
output_path.parent.mkdir(parents=True, exist_ok=True)
32+
output_path.write_text(gm.print_readable(print_output=False))
33+
34+
2435
def export_joint(
25-
model, args, kwargs=None
36+
model, args, kwargs=None, dump_folder: str | None = None
2637
) -> tuple[JointWithDescriptors, TracingContext]:
2738
if kwargs is None:
2839
kwargs = {}
@@ -35,8 +46,10 @@ def export_joint(
3546
torch.fx.traceback.preserve_node_meta(),
3647
):
3748
gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
38-
logger.info("Dynamo gm:")
39-
logger.info(gm.print_readable(print_output=False))
49+
logger.debug("Dynamo gm:")
50+
logger.debug(gm.print_readable(print_output=False))
51+
_dump_gm(dump_folder, gm, "dynamo_gm")
52+
4053
tracing_context = gm.meta["tracing_context"]
4154

4255
with tracing(tracing_context):
@@ -68,6 +81,7 @@ def joint_graph_builder(
6881
fw_compiler: Optional[Callable] = None,
6982
bw_compiler: Optional[Callable] = None,
7083
joint_custom_pass: Optional[Callable] = None,
84+
dump_folder: str | None = None,
7185
):
7286
"""
7387
Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@ def joint_graph_builder(
7993
fw_compiler: Optional custom forward compiler function
8094
bw_compiler: Optional custom backward compiler function
8195
joint_custom_pass: Optional custom pass to run on the joint graph
96+
dump_folder: Optional folder to dump the graph to
8297
"""
8398
assert isinstance(model_args, tuple)
84-
for arg in model_args:
85-
assert isinstance(arg, DTensor)
99+
for idx, arg in enumerate(model_args):
100+
assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"
86101

87102
# get joint graph
88103
(
89104
joint_with_descriptors,
90105
tracing_context,
91-
) = export_joint(model, model_args, model_kwargs)
106+
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
92107

93108
# Optional validation
94109
if joint_custom_pass is not None:
@@ -153,6 +168,18 @@ def __delattr__(self, name: str) -> None:
153168
else:
154169
super().__delattr__(name)
155170

171+
def state_dict(self, *args, **kwargs) -> Any:
172+
return self.inner.state_dict(*args, **kwargs)
173+
174+
def load_state_dict(self, *args, **kwargs) -> Any:
175+
return self.inner.load_state_dict(*args, **kwargs)
176+
177+
def name_parameters(self, *args, **kwargs) -> Any:
178+
return self.inner.named_parameters(*args, **kwargs)
179+
180+
def parameters(self, *args, **kwargs) -> Any:
181+
return self.inner.parameters(*args, **kwargs)
182+
156183
def forward(self, *args, **kwargs):
157184
assert "forward" not in self._overrides, "forward cannot be overridden"
158185

@@ -179,6 +206,7 @@ def compiler(
179206
gm: torch.fx.GraphModule,
180207
example_inputs,
181208
passes: List[Callable] = None,
209+
dump_folder: str | None = None,
182210
):
183211
"""
184212
Compile a graph module by applying a sequence of compiler passes.
@@ -190,23 +218,28 @@ def compiler(
190218
passes: List of compiler pass functions to apply. Each function should take
191219
(gm, example_inputs) and return a transformed gm. If None, uses
192220
DEFAULT_COMPILER_PASSES.
221+
dump_folder: Optional folder to dump the graph to
193222
"""
194223
if passes is None:
195224
passes = DEFAULT_COMPILER_PASSES
196225

197-
logger.info(f"{name} before compiler:")
198-
logger.info(gm.print_readable(print_output=False))
226+
logger.debug(f"{name} before compiler:")
227+
logger.debug(gm.print_readable(print_output=False))
228+
_dump_gm(dump_folder, gm, f"{name}_before_compiler")
199229

200230
for pass_fn in passes:
201231
logger.info(f"Applying pass: {pass_fn.__name__}")
202232
gm = pass_fn(gm, example_inputs)
203233

204-
logger.info(f"{name} after compiler:")
205-
logger.info(gm.print_readable(print_output=False))
234+
logger.debug(f"{name} after compiler:")
235+
logger.debug(gm.print_readable(print_output=False))
236+
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
206237
return gm
207238

208239

209-
def make_compiler_with_passes(passes: List[Callable] = None):
240+
def make_compiler_with_passes(
241+
passes: List[Callable] = None, dump_folder: str | None = None
242+
):
210243
"""
211244
Create forward and backward compilers with specified passes.
212245
@@ -218,10 +251,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
218251
"""
219252

220253
def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
221-
return compiler("fwd_gm", gm, example_inputs, passes=passes)
254+
return compiler(
255+
"fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
256+
)
222257

223258
def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
224-
return compiler("bwd_gm", gm, example_inputs, passes=passes)
259+
return compiler(
260+
"bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
261+
)
225262

226263
return fw_compiler, bw_compiler
227264

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,17 @@ def parallelize_llama(
6767
compiler_passes = get_compiler_passes_from_config(job_config)
6868

6969
# Create compilers with specified passes (defaults to no passes)
70-
fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
70+
fw_compiler, bw_compiler = make_compiler_with_passes(
71+
compiler_passes, dump_folder=job_config.job.dump_folder
72+
)
7173

7274
# Create custom joint_graph_builder with llama-specific compilers and validation
7375
llama_joint_graph_builder = functools.partial(
7476
joint_graph_builder,
7577
fw_compiler=fw_compiler,
7678
bw_compiler=bw_compiler,
7779
joint_custom_pass=validate_flex_attention_annotation,
80+
dump_folder=job_config.job.dump_folder,
7881
)
7982

8083
# TODO: CompiledModule should take sample input as well, so that we can

0 commit comments

Comments
 (0)