Commit 6b85551
Make compiler toolkit work with checkpoint

The current CompiledModule results in an "inner" prefix on everything (state-dict keys, parameter names). This PR fixes it by overriding the relevant methods to delegate to the inner module.

ghstack-source-id: a16c514
Pull-Request: #2030
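Why the overrides are needed, as a minimal standalone sketch (Wrapper and FixedWrapper are hypothetical names for illustration, not the torchtitan classes): assigning the wrapped module to an attribute registers it as a submodule, so every state-dict key gets an "inner." prefix; delegating the accessors to the inner module restores the original keys that checkpoint save/load expects.

import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        # Assigning a module attribute registers it as a submodule,
        # so every key is reported as "inner.<name>".
        self.inner = inner

class FixedWrapper(Wrapper):
    # Delegate to the wrapped module so keys keep their original names.
    def state_dict(self, *args, **kwargs):
        return self.inner.state_dict(*args, **kwargs)

model = nn.Linear(2, 2)
print(list(Wrapper(model).state_dict()))       # ['inner.weight', 'inner.bias']
print(list(FixedWrapper(model).state_dict()))  # ['weight', 'bias']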

3 files changed: +59, -16 lines

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 4 additions & 1 deletion
@@ -80,14 +80,17 @@ def parallelize_deepseekv3(
     compiler_passes = get_compiler_passes_from_config(job_config)
 
     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )
 
     # Create custom joint_graph_builder with deepseekv3-specific compilers
     deepseekv3_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
+        dump_folder=job_config.job.dump_folder,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 51 additions & 14 deletions
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
-from typing import Callable, List, Optional
+from pathlib import Path
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -21,8 +22,18 @@
 from torchtitan.tools.logging import logger
 
 
+def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
+    # TODO: make the dump rank configurable
+    if not dump_folder or torch.distributed.get_rank() != 0:
+        return
+
+    output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path.write_text(gm.print_readable(print_output=False))
+
+
 def export_joint(
-    model, args, kwargs=None
+    model, args, kwargs=None, dump_folder: str | None = None
 ) -> tuple[JointWithDescriptors, TracingContext]:
     if kwargs is None:
         kwargs = {}
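One caveat: torch.distributed.get_rank() in the guard assumes an initialized process group. A hypothetical variant (not part of this PR) that also tolerates single-process runs could look like:

import torch.distributed as dist

def _should_dump(dump_folder: str | None) -> bool:
    # No folder configured: never dump.
    if not dump_folder:
        return False
    # Distributed run: dump from rank 0 only, as in the PR.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    # Single-process run: get_rank() would raise, so just dump.
    return True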
@@ -35,8 +46,10 @@ def export_joint(
         torch.fx.traceback.preserve_node_meta(),
     ):
         gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
-        logger.info("Dynamo gm:")
-        logger.info(gm.print_readable(print_output=False))
+        logger.debug("Dynamo gm:")
+        logger.debug(gm.print_readable(print_output=False))
+        _dump_gm(dump_folder, gm, "dynamo_gm")
+
     tracing_context = gm.meta["tracing_context"]
 
     with tracing(tracing_context):
@@ -68,6 +81,7 @@ def joint_graph_builder(
     fw_compiler: Optional[Callable] = None,
     bw_compiler: Optional[Callable] = None,
     joint_custom_pass: Optional[Callable] = None,
+    dump_folder: str | None = None,
 ):
     """
     Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@ def joint_graph_builder(
         fw_compiler: Optional custom forward compiler function
         bw_compiler: Optional custom backward compiler function
         joint_custom_pass: Optional custom pass to run on the joint graph
+        dump_folder: Optional folder to dump the graph to
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+    for idx, arg in enumerate(model_args):
+        assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"
 
     # get joint graph
     (
         joint_with_descriptors,
         tracing_context,
-    ) = export_joint(model, model_args, model_kwargs)
+    ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
 
     # Optional validation
     if joint_custom_pass is not None:
@@ -153,6 +168,18 @@ def __delattr__(self, name: str) -> None:
         else:
             super().__delattr__(name)
 
+    def state_dict(self, *args, **kwargs) -> Any:
+        return self.inner.state_dict(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs) -> Any:
+        return self.inner.load_state_dict(*args, **kwargs)
+
+    def named_parameters(self, *args, **kwargs) -> Any:
+        return self.inner.named_parameters(*args, **kwargs)
+
+    def parameters(self, *args, **kwargs) -> Any:
+        return self.inner.parameters(*args, **kwargs)
+
     def forward(self, *args, **kwargs):
         assert "forward" not in self._overrides, "forward cannot be overridden"
 
@@ -179,6 +206,7 @@ def compiler(
     gm: torch.fx.GraphModule,
     example_inputs,
     passes: List[Callable] = None,
+    dump_folder: str | None = None,
 ):
     """
     Compile a graph module by applying a sequence of compiler passes.
@@ -190,23 +218,28 @@ def compiler(
         passes: List of compiler pass functions to apply. Each function should take
             (gm, example_inputs) and return a transformed gm. If None, uses
             DEFAULT_COMPILER_PASSES.
+        dump_folder: Optional folder to dump the graph to
     """
     if passes is None:
         passes = DEFAULT_COMPILER_PASSES
 
-    logger.info(f"{name} before compiler:")
-    logger.info(gm.print_readable(print_output=False))
+    logger.debug(f"{name} before compiler:")
+    logger.debug(gm.print_readable(print_output=False))
+    _dump_gm(dump_folder, gm, f"{name}_before_compiler")
 
     for pass_fn in passes:
         logger.info(f"Applying pass: {pass_fn.__name__}")
         gm = pass_fn(gm, example_inputs)
 
-    logger.info(f"{name} after compiler:")
-    logger.info(gm.print_readable(print_output=False))
+    logger.debug(f"{name} after compiler:")
+    logger.debug(gm.print_readable(print_output=False))
+    _dump_gm(dump_folder, gm, f"{name}_after_compiler")
     return gm
 
 
-def make_compiler_with_passes(passes: List[Callable] = None):
+def make_compiler_with_passes(
+    passes: List[Callable] = None, dump_folder: str | None = None
+):
     """
     Create forward and backward compilers with specified passes.
 
@@ -218,10 +251,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
     """
 
     def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler("fwd_gm", gm, example_inputs, passes=passes)
+        return compiler(
+            "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+        )
 
     def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler("bwd_gm", gm, example_inputs, passes=passes)
+        return compiler(
+            "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+        )
 
     return fw_compiler, bw_compiler
 
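Since compiler() applies each pass as gm = pass_fn(gm, example_inputs), a pass is just a function with that signature. A minimal sketch (print_nodes_pass and the "./outputs" folder are illustrative assumptions, not part of the toolkit):

import torch

def print_nodes_pass(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
    # Log every node, then return the graph unchanged.
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm

fw_compiler, bw_compiler = make_compiler_with_passes(
    passes=[print_nodes_pass], dump_folder="./outputs"
)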

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 4 additions & 1 deletion
@@ -67,14 +67,17 @@ def parallelize_llama(
     compiler_passes = get_compiler_passes_from_config(job_config)
 
     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )
 
     # Create custom joint_graph_builder with llama-specific compilers and validation
     llama_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
+        dump_folder=job_config.job.dump_folder,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
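Taken together, with job.dump_folder set (assume "./outputs" for illustration), a run would be expected to leave rank-0 graph dumps at:

./outputs/compiler/dynamo_gm.txt
./outputs/compiler/fwd_gm_before_compiler.txt
./outputs/compiler/fwd_gm_after_compiler.txt
./outputs/compiler/bwd_gm_before_compiler.txt
./outputs/compiler/bwd_gm_after_compiler.txt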
