55# LICENSE file in the root directory of this source tree.
66
77import contextlib
8- from typing import Callable , List , Optional
8+ from pathlib import Path
9+ from typing import Any , Callable , List , Optional
910
1011import torch
1112from torch ._dynamo .functional_export import dynamo_graph_capture_for_export
2122from torchtitan .tools .logging import logger
2223
2324
25+ def _dump_gm (dump_folder : str | None , gm : torch .fx .GraphModule , name : str ) -> None :
26+ # TODO: make the dump rank configurable
27+ if not dump_folder or torch .distributed .get_rank () != 0 :
28+ return
29+
30+ output_path = Path (dump_folder ) / "compiler" / f"{ name } .txt"
31+ output_path .parent .mkdir (parents = True , exist_ok = True )
32+ output_path .write_text (gm .print_readable (print_output = False ))
33+
34+
2435def export_joint (
25- model , args , kwargs = None
36+ model , args , kwargs = None , dump_folder : str | None = None
2637) -> tuple [JointWithDescriptors , TracingContext ]:
2738 if kwargs is None :
2839 kwargs = {}
@@ -35,8 +46,10 @@ def export_joint(
3546 torch .fx .traceback .preserve_node_meta (),
3647 ):
3748 gm = dynamo_graph_capture_for_export (model )(* args , ** kwargs )
38- logger .info ("Dynamo gm:" )
39- logger .info (gm .print_readable (print_output = False ))
49+ logger .debug ("Dynamo gm:" )
50+ logger .debug (gm .print_readable (print_output = False ))
51+ _dump_gm (dump_folder , gm , "dynamo_gm" )
52+
4053 tracing_context = gm .meta ["tracing_context" ]
4154
4255 with tracing (tracing_context ):
@@ -68,6 +81,7 @@ def joint_graph_builder(
6881 fw_compiler : Optional [Callable ] = None ,
6982 bw_compiler : Optional [Callable ] = None ,
7083 joint_custom_pass : Optional [Callable ] = None ,
84+ dump_folder : str | None = None ,
7185):
7286 """
7387 Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@ def joint_graph_builder(
7993 fw_compiler: Optional custom forward compiler function
8094 bw_compiler: Optional custom backward compiler function
8195 joint_custom_pass: Optional custom pass to run on the joint graph
96+ dump_folder: Optional folder to dump the graph to
8297 """
8398 assert isinstance (model_args , tuple )
84- for arg in model_args :
85- assert isinstance (arg , DTensor )
99+ for idx , arg in enumerate ( model_args ) :
100+ assert isinstance (arg , DTensor ), f"Argument { idx } is of type { type ( arg ) } "
86101
87102 # get joint graph
88103 (
89104 joint_with_descriptors ,
90105 tracing_context ,
91- ) = export_joint (model , model_args , model_kwargs )
106+ ) = export_joint (model , model_args , model_kwargs , dump_folder = dump_folder )
92107
93108 # Optional validation
94109 if joint_custom_pass is not None :
@@ -153,6 +168,18 @@ def __delattr__(self, name: str) -> None:
153168 else :
154169 super ().__delattr__ (name )
155170
def state_dict(self, *args, **kwargs) -> Any:
    """Return the result of ``self.inner.state_dict``.

    All positional and keyword arguments are forwarded unchanged.
    """
    wrapped = self.inner
    return wrapped.state_dict(*args, **kwargs)
173+
def load_state_dict(self, *args, **kwargs) -> Any:
    """Forward to ``self.inner.load_state_dict`` and return its result.

    All positional and keyword arguments are passed through unchanged.
    """
    wrapped = self.inner
    return wrapped.load_state_dict(*args, **kwargs)
176+
def named_parameters(self, *args, **kwargs) -> Any:
    """Return the result of ``self.inner.named_parameters``.

    All positional and keyword arguments are forwarded unchanged.

    NOTE(review): the original method was spelled ``name_parameters``, so a
    call to ``named_parameters`` never reached this delegate. The enclosing
    class is not visible here; presumably it wraps an ``nn.Module`` and is
    meant to expose the inner module's parameter names, as the sibling
    ``parameters``/``state_dict`` delegates do -- confirm against callers.
    """
    return self.inner.named_parameters(*args, **kwargs)

# Backward-compatible alias for the original misspelled method name.
name_parameters = named_parameters
179+
def parameters(self, *args, **kwargs) -> Any:
    """Return the result of ``self.inner.parameters``.

    All positional and keyword arguments are forwarded unchanged.
    """
    wrapped = self.inner
    return wrapped.parameters(*args, **kwargs)
182+
156183 def forward (self , * args , ** kwargs ):
157184 assert "forward" not in self ._overrides , "forward cannot be overridden"
158185
@@ -179,6 +206,7 @@ def compiler(
179206 gm : torch .fx .GraphModule ,
180207 example_inputs ,
181208 passes : List [Callable ] = None ,
209+ dump_folder : str | None = None ,
182210):
183211 """
184212 Compile a graph module by applying a sequence of compiler passes.
@@ -190,23 +218,28 @@ def compiler(
190218 passes: List of compiler pass functions to apply. Each function should take
191219 (gm, example_inputs) and return a transformed gm. If None, uses
192220 DEFAULT_COMPILER_PASSES.
221+ dump_folder: Optional folder to dump the graph to
193222 """
194223 if passes is None :
195224 passes = DEFAULT_COMPILER_PASSES
196225
197- logger .info (f"{ name } before compiler:" )
198- logger .info (gm .print_readable (print_output = False ))
226+ logger .debug (f"{ name } before compiler:" )
227+ logger .debug (gm .print_readable (print_output = False ))
228+ _dump_gm (dump_folder , gm , f"{ name } _before_compiler" )
199229
200230 for pass_fn in passes :
201231 logger .info (f"Applying pass: { pass_fn .__name__ } " )
202232 gm = pass_fn (gm , example_inputs )
203233
204- logger .info (f"{ name } after compiler:" )
205- logger .info (gm .print_readable (print_output = False ))
234+ logger .debug (f"{ name } after compiler:" )
235+ logger .debug (gm .print_readable (print_output = False ))
236+ _dump_gm (dump_folder , gm , f"{ name } _after_compiler" )
206237 return gm
207238
208239
209- def make_compiler_with_passes (passes : List [Callable ] = None ):
240+ def make_compiler_with_passes (
241+ passes : List [Callable ] = None , dump_folder : str | None = None
242+ ):
210243 """
211244 Create forward and backward compilers with specified passes.
212245
@@ -218,10 +251,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
218251 """
219252
220253 def fw_compiler (gm : torch .fx .GraphModule , example_inputs ) -> None :
221- return compiler ("fwd_gm" , gm , example_inputs , passes = passes )
254+ return compiler (
255+ "fwd_gm" , gm , example_inputs , passes = passes , dump_folder = dump_folder
256+ )
222257
223258 def bw_compiler (gm : torch .fx .GraphModule , example_inputs ) -> None :
224- return compiler ("bwd_gm" , gm , example_inputs , passes = passes )
259+ return compiler (
260+ "bwd_gm" , gm , example_inputs , passes = passes , dump_folder = dump_folder
261+ )
225262
226263 return fw_compiler , bw_compiler
227264
0 commit comments