Skip to content

Commit 2d4561b

Browse files
committed
Improve compiler toolkit logging
Summary: 1. When an input is not a DTensor, print out its index and type. 2. Graph dumps spam the initialization log, so change them to debug level and output the graphs to the dump folder. ghstack-source-id: fb2239f Pull-Request: #2028
1 parent 7d9a266 commit 2d4561b

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,17 @@ def parallelize_deepseekv3(
8080
compiler_passes = get_compiler_passes_from_config(job_config)
8181

8282
# Create compilers with specified passes (defaults to no passes)
83-
fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
83+
fw_compiler, bw_compiler = make_compiler_with_passes(
84+
compiler_passes, dump_folder=job_config.job.dump_folder
85+
)
8486

8587
# Create custom joint_graph_builder with deepseekv3-specific compilers
8688
deepseekv3_joint_graph_builder = functools.partial(
8789
joint_graph_builder,
8890
fw_compiler=fw_compiler,
8991
bw_compiler=bw_compiler,
9092
joint_custom_pass=validate_flex_attention_annotation,
93+
dump_folder=job_config.job.dump_folder,
9194
)
9295

9396
# TODO: CompiledModule should take sample input as well, so that we can

torchtitan/experiments/compiler_toolkit/graph_utils.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import contextlib
8+
from pathlib import Path
89
from typing import Callable, List, Optional
910

1011
import torch
@@ -21,8 +22,18 @@
2122
from torchtitan.tools.logging import logger
2223

2324

25+
def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
26+
# TODO: make the dump rank configurable
27+
if not dump_folder or torch.distributed.get_rank() != 0:
28+
return
29+
30+
output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
31+
output_path.parent.mkdir(parents=True, exist_ok=True)
32+
output_path.write_text(gm.print_readable(print_output=False))
33+
34+
2435
def export_joint(
25-
model, args, kwargs=None
36+
model, args, kwargs=None, dump_folder: str | None = None
2637
) -> tuple[JointWithDescriptors, TracingContext]:
2738
if kwargs is None:
2839
kwargs = {}
@@ -35,8 +46,10 @@ def export_joint(
3546
torch.fx.traceback.preserve_node_meta(),
3647
):
3748
gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
38-
logger.info("Dynamo gm:")
39-
logger.info(gm.print_readable(print_output=False))
49+
logger.debug("Dynamo gm:")
50+
logger.debug(gm.print_readable(print_output=False))
51+
_dump_gm(dump_folder, gm, "dynamo_gm")
52+
4053
tracing_context = gm.meta["tracing_context"]
4154

4255
with tracing(tracing_context):
@@ -68,6 +81,7 @@ def joint_graph_builder(
6881
fw_compiler: Optional[Callable] = None,
6982
bw_compiler: Optional[Callable] = None,
7083
joint_custom_pass: Optional[Callable] = None,
84+
dump_folder: str | None = None,
7185
):
7286
"""
7387
Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@ def joint_graph_builder(
7993
fw_compiler: Optional custom forward compiler function
8094
bw_compiler: Optional custom backward compiler function
8195
joint_custom_pass: Optional custom pass to run on the joint graph
96+
dump_folder: Optional folder to dump the graph to
8297
"""
8398
assert isinstance(model_args, tuple)
84-
for arg in model_args:
85-
assert isinstance(arg, DTensor)
99+
for idx, arg in enumerate(model_args):
100+
assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"
86101

87102
# get joint graph
88103
(
89104
joint_with_descriptors,
90105
tracing_context,
91-
) = export_joint(model, model_args, model_kwargs)
106+
) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
92107

93108
# Optional validation
94109
if joint_custom_pass is not None:
@@ -179,6 +194,7 @@ def compiler(
179194
gm: torch.fx.GraphModule,
180195
example_inputs,
181196
passes: List[Callable] = None,
197+
dump_folder: str | None = None,
182198
):
183199
"""
184200
Compile a graph module by applying a sequence of compiler passes.
@@ -194,19 +210,23 @@ def compiler(
194210
if passes is None:
195211
passes = DEFAULT_COMPILER_PASSES
196212

197-
logger.info(f"{name} before compiler:")
198-
logger.info(gm.print_readable(print_output=False))
213+
logger.debug(f"{name} before compiler:")
214+
logger.debug(gm.print_readable(print_output=False))
215+
_dump_gm(dump_folder, gm, f"{name}_before_compiler")
199216

200217
for pass_fn in passes:
201218
logger.info(f"Applying pass: {pass_fn.__name__}")
202219
gm = pass_fn(gm, example_inputs)
203220

204-
logger.info(f"{name} after compiler:")
205-
logger.info(gm.print_readable(print_output=False))
221+
logger.debug(f"{name} after compiler:")
222+
logger.debug(gm.print_readable(print_output=False))
223+
_dump_gm(dump_folder, gm, f"{name}_after_compiler")
206224
return gm
207225

208226

209-
def make_compiler_with_passes(passes: List[Callable] = None):
227+
def make_compiler_with_passes(
228+
passes: List[Callable] = None, dump_folder: str | None = None
229+
):
210230
"""
211231
Create forward and backward compilers with specified passes.
212232
@@ -218,10 +238,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
218238
"""
219239

220240
def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
221-
return compiler("fwd_gm", gm, example_inputs, passes=passes)
241+
return compiler(
242+
"fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
243+
)
222244

223245
def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
224-
return compiler("bwd_gm", gm, example_inputs, passes=passes)
246+
return compiler(
247+
"bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
248+
)
225249

226250
return fw_compiler, bw_compiler
227251

torchtitan/experiments/compiler_toolkit/llama3/parallelize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,17 @@ def parallelize_llama(
6767
compiler_passes = get_compiler_passes_from_config(job_config)
6868

6969
# Create compilers with specified passes (defaults to no passes)
70-
fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
70+
fw_compiler, bw_compiler = make_compiler_with_passes(
71+
compiler_passes, dump_folder=job_config.job.dump_folder
72+
)
7173

7274
# Create custom joint_graph_builder with llama-specific compilers and validation
7375
llama_joint_graph_builder = functools.partial(
7476
joint_graph_builder,
7577
fw_compiler=fw_compiler,
7678
bw_compiler=bw_compiler,
7779
joint_custom_pass=validate_flex_attention_annotation,
80+
dump_folder=job_config.job.dump_folder,
7881
)
7982

8083
# TODO: CompiledModule should take sample input as well, so that we can

0 commit comments

Comments
 (0)