55# LICENSE file in the root directory of this source tree.
66
77import contextlib
8+ from pathlib import Path
89from typing import Callable , List , Optional
910
1011import torch
2122from torchtitan .tools .logging import logger
2223
2324
25+ def _dump_gm (dump_folder : str | None , gm : torch .fx .GraphModule , name : str ) -> None :
26+ # TODO: make the dump rank configurable
27+ if not dump_folder or torch .distributed .get_rank () != 0 :
28+ return
29+
30+ output_path = Path (dump_folder ) / "compiler" / f"{ name } .txt"
31+ output_path .parent .mkdir (parents = True , exist_ok = True )
32+ output_path .write_text (gm .print_readable (print_output = False ))
33+
34+
2435def export_joint (
25- model , args , kwargs = None
36+ model , args , kwargs = None , dump_folder : str | None = None
2637) -> tuple [JointWithDescriptors , TracingContext ]:
2738 if kwargs is None :
2839 kwargs = {}
@@ -35,8 +46,10 @@ def export_joint(
3546 torch .fx .traceback .preserve_node_meta (),
3647 ):
3748 gm = dynamo_graph_capture_for_export (model )(* args , ** kwargs )
38- logger .info ("Dynamo gm:" )
39- logger .info (gm .print_readable (print_output = False ))
49+ logger .debug ("Dynamo gm:" )
50+ logger .debug (gm .print_readable (print_output = False ))
51+ _dump_gm (dump_folder , gm , "dynamo_gm" )
52+
4053 tracing_context = gm .meta ["tracing_context" ]
4154
4255 with tracing (tracing_context ):
@@ -68,6 +81,7 @@ def joint_graph_builder(
6881 fw_compiler : Optional [Callable ] = None ,
6982 bw_compiler : Optional [Callable ] = None ,
7083 joint_custom_pass : Optional [Callable ] = None ,
84+ dump_folder : str | None = None ,
7185):
7286 """
7387 Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@ def joint_graph_builder(
7993 fw_compiler: Optional custom forward compiler function
8094 bw_compiler: Optional custom backward compiler function
8195 joint_custom_pass: Optional custom pass to run on the joint graph
96+ dump_folder: Optional folder to dump the graph to
8297 """
8398 assert isinstance (model_args , tuple )
84- for arg in model_args :
85- assert isinstance (arg , DTensor )
99+ for idx , arg in enumerate ( model_args ) :
100+ assert isinstance (arg , DTensor ), f"Argument { idx } is of type { type ( arg ) } "
86101
87102 # get joint graph
88103 (
89104 joint_with_descriptors ,
90105 tracing_context ,
91- ) = export_joint (model , model_args , model_kwargs )
106+ ) = export_joint (model , model_args , model_kwargs , dump_folder = dump_folder )
92107
93108 # Optional validation
94109 if joint_custom_pass is not None :
@@ -179,6 +194,7 @@ def compiler(
179194 gm : torch .fx .GraphModule ,
180195 example_inputs ,
181196 passes : List [Callable ] = None ,
197+ dump_folder : str | None = None ,
182198):
183199 """
184200 Compile a graph module by applying a sequence of compiler passes.
@@ -190,23 +206,28 @@ def compiler(
190206 passes: List of compiler pass functions to apply. Each function should take
191207 (gm, example_inputs) and return a transformed gm. If None, uses
192208 DEFAULT_COMPILER_PASSES.
209+ dump_folder: Optional folder to dump the graph to
193210 """
194211 if passes is None :
195212 passes = DEFAULT_COMPILER_PASSES
196213
197- logger .info (f"{ name } before compiler:" )
198- logger .info (gm .print_readable (print_output = False ))
214+ logger .debug (f"{ name } before compiler:" )
215+ logger .debug (gm .print_readable (print_output = False ))
216+ _dump_gm (dump_folder , gm , f"{ name } _before_compiler" )
199217
200218 for pass_fn in passes :
201219 logger .info (f"Applying pass: { pass_fn .__name__ } " )
202220 gm = pass_fn (gm , example_inputs )
203221
204- logger .info (f"{ name } after compiler:" )
205- logger .info (gm .print_readable (print_output = False ))
222+ logger .debug (f"{ name } after compiler:" )
223+ logger .debug (gm .print_readable (print_output = False ))
224+ _dump_gm (dump_folder , gm , f"{ name } _after_compiler" )
206225 return gm
207226
208227
209- def make_compiler_with_passes (passes : List [Callable ] = None ):
228+ def make_compiler_with_passes (
229+ passes : List [Callable ] = None , dump_folder : str | None = None
230+ ):
210231 """
211232 Create forward and backward compilers with specified passes.
212233
@@ -218,10 +239,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
218239 """
219240
220241 def fw_compiler (gm : torch .fx .GraphModule , example_inputs ) -> None :
221- return compiler ("fwd_gm" , gm , example_inputs , passes = passes )
242+ return compiler (
243+ "fwd_gm" , gm , example_inputs , passes = passes , dump_folder = dump_folder
244+ )
222245
223246 def bw_compiler (gm : torch .fx .GraphModule , example_inputs ) -> None :
224- return compiler ("bwd_gm" , gm , example_inputs , passes = passes )
247+ return compiler (
248+ "bwd_gm" , gm , example_inputs , passes = passes , dump_folder = dump_folder
249+ )
225250
226251 return fw_compiler , bw_compiler
227252
0 commit comments