@@ -80,14 +80,17 @@ def parallelize_deepseekv3(
     compiler_passes = get_compiler_passes_from_config(job_config)

     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )

     # Create custom joint_graph_builder with deepseekv3-specific compilers
     deepseekv3_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
+        dump_folder=job_config.job.dump_folder,
     )

     # TODO: CompiledModule should take sample input as well, so that we can
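
For orientation, a hedged sketch of what this wiring yields (the eventual call site is outside this diff; the argument names follow joint_graph_builder's docstring below):

    # Hypothetical invocation shape, not shown in this PR:
    deepseekv3_joint_graph_builder(model, model_args, model_kwargs)
    # export_joint and both compilers now receive dump_folder, so every
    # captured graph is written under <dump_folder>/compiler/ on rank 0.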
torchtitan/experiments/compiler_toolkit/graph_utils.py (65 changes: 51 additions & 14 deletions)
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
-from typing import Callable, List, Optional
+from pathlib import Path
+from typing import Any, Callable, List, Optional

 import torch
 from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -21,8 +22,18 @@
 from torchtitan.tools.logging import logger


+def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
+    # TODO: make the dump rank configurable
+    if not dump_folder or torch.distributed.get_rank() != 0:
+        return
+
+    output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path.write_text(gm.print_readable(print_output=False))
+
+
 def export_joint(
-    model, args, kwargs=None
+    model, args, kwargs=None, dump_folder: str | None = None
 ) -> tuple[JointWithDescriptors, TracingContext]:
     if kwargs is None:
         kwargs = {}
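
A minimal sketch of `_dump_gm` in isolation, assuming a process group is already initialized (it calls `torch.distributed.get_rank()` whenever a folder is given); the folder name and traced function here are placeholders:

    import torch
    import torch.distributed as dist

    # _dump_gm consults torch.distributed.get_rank(), so set up a
    # single-process group first (illustrative setup only).
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

    def f(x):
        return x + 1

    gm = torch.fx.symbolic_trace(f)  # any GraphModule works
    _dump_gm("./outputs", gm, "dynamo_gm")
    # rank 0 now has ./outputs/compiler/dynamo_gm.txt; with dump_folder=None
    # the call short-circuits into a no-op.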
@@ -35,8 +46,10 @@ def export_joint(
         torch.fx.traceback.preserve_node_meta(),
     ):
         gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
-        logger.info("Dynamo gm:")
-        logger.info(gm.print_readable(print_output=False))
+        logger.debug("Dynamo gm:")
+        logger.debug(gm.print_readable(print_output=False))
+        _dump_gm(dump_folder, gm, "dynamo_gm")
+
     tracing_context = gm.meta["tracing_context"]

     with tracing(tracing_context):
@@ -68,6 +81,7 @@ def joint_graph_builder(
     fw_compiler: Optional[Callable] = None,
     bw_compiler: Optional[Callable] = None,
     joint_custom_pass: Optional[Callable] = None,
+    dump_folder: str | None = None,
 ):
     """
     Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@
         fw_compiler: Optional custom forward compiler function
         bw_compiler: Optional custom backward compiler function
         joint_custom_pass: Optional custom pass to run on the joint graph
+        dump_folder: Optional folder to dump the graph to
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+    for idx, arg in enumerate(model_args):
+        assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"

     # get joint graph
     (
         joint_with_descriptors,
         tracing_context,
-    ) = export_joint(model, model_args, model_kwargs)
+    ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)

     # Optional validation
     if joint_custom_pass is not None:
@@ -153,6 +168,18 @@ def __delattr__(self, name: str) -> None:
         else:
             super().__delattr__(name)

+    def state_dict(self, *args, **kwargs) -> Any:
+        return self.inner.state_dict(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs) -> Any:
+        return self.inner.load_state_dict(*args, **kwargs)
+
+    def named_parameters(self, *args, **kwargs) -> Any:
+        return self.inner.named_parameters(*args, **kwargs)
+
+    def parameters(self, *args, **kwargs) -> Any:
+        return self.inner.parameters(*args, **kwargs)
+
     def forward(self, *args, **kwargs):
         assert "forward" not in self._overrides, "forward cannot be overridden"

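The four delegating methods above keep checkpointing and optimizer setup pointed at the wrapped module's own parameter names. A toy sketch of the pattern (hypothetical `Wrapper`, not the real CompiledModule):

    import torch.nn as nn

    class Wrapper(nn.Module):
        def __init__(self, inner: nn.Module):
            super().__init__()
            self.inner = inner

        def state_dict(self, *args, **kwargs):
            # Delegate so checkpoint keys match the unwrapped model.
            return self.inner.state_dict(*args, **kwargs)

    w = Wrapper(nn.Linear(2, 2))
    # nn.Module's default would produce "inner.weight" / "inner.bias";
    # delegation preserves the original "weight" / "bias" keys.
    assert set(w.state_dict()) == {"weight", "bias"}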
@@ -179,6 +206,7 @@ def compiler(
     gm: torch.fx.GraphModule,
     example_inputs,
     passes: List[Callable] = None,
+    dump_folder: str | None = None,
 ):
     """
     Compile a graph module by applying a sequence of compiler passes.
@@ -190,23 +218,28 @@
         passes: List of compiler pass functions to apply. Each function should take
                 (gm, example_inputs) and return a transformed gm. If None, uses
                 DEFAULT_COMPILER_PASSES.
+        dump_folder: Optional folder to dump the graph to
     """
     if passes is None:
         passes = DEFAULT_COMPILER_PASSES

-    logger.info(f"{name} before compiler:")
-    logger.info(gm.print_readable(print_output=False))
+    logger.debug(f"{name} before compiler:")
+    logger.debug(gm.print_readable(print_output=False))
+    _dump_gm(dump_folder, gm, f"{name}_before_compiler")

     for pass_fn in passes:
         logger.info(f"Applying pass: {pass_fn.__name__}")
         gm = pass_fn(gm, example_inputs)

-    logger.info(f"{name} after compiler:")
-    logger.info(gm.print_readable(print_output=False))
+    logger.debug(f"{name} after compiler:")
+    logger.debug(gm.print_readable(print_output=False))
+    _dump_gm(dump_folder, gm, f"{name}_after_compiler")
     return gm


-def make_compiler_with_passes(passes: List[Callable] = None):
+def make_compiler_with_passes(
+    passes: List[Callable] = None, dump_folder: str | None = None
+):
     """
     Create forward and backward compilers with specified passes.

@@ -218,10 +251,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
     """

     def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler("fwd_gm", gm, example_inputs, passes=passes)
+        return compiler(
+            "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+        )

     def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler("bwd_gm", gm, example_inputs, passes=passes)
+        return compiler(
+            "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+        )

     return fw_compiler, bw_compiler

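Putting the two pieces together: each entry in `passes` must accept `(gm, example_inputs)` and return a `torch.fx.GraphModule`, and the factory closes over both `passes` and `dump_folder`. A hedged sketch with a do-nothing pass (hypothetical, not one of DEFAULT_COMPILER_PASSES):

    import torch

    def annotate_pass(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
        # Conforming no-op: tag every node, then return the same gm.
        for node in gm.graph.nodes:
            node.meta["visited"] = True
        return gm

    fw_compiler, bw_compiler = make_compiler_with_passes(
        passes=[annotate_pass], dump_folder="./outputs"  # folder is a placeholder
    )
    # fw_compiler(gm, inputs) dumps fwd_gm_before_compiler / fwd_gm_after_compiler;
    # bw_compiler does the same with the bwd_gm prefix.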
@@ -67,14 +67,17 @@ def parallelize_llama(
     compiler_passes = get_compiler_passes_from_config(job_config)

     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )

     # Create custom joint_graph_builder with llama-specific compilers and validation
     llama_joint_graph_builder = functools.partial(
         joint_graph_builder,
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
+        dump_folder=job_config.job.dump_folder,
     )

     # TODO: CompiledModule should take sample input as well, so that we can
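
Net effect: with `job.dump_folder` set (say `./outputs`, a placeholder value), rank 0 ends up with a readable snapshot of every stage:

    ./outputs/compiler/dynamo_gm.txt                 # export_joint capture
    ./outputs/compiler/fwd_gm_before_compiler.txt    # forward, pre-passes
    ./outputs/compiler/fwd_gm_after_compiler.txt     # forward, post-passes
    ./outputs/compiler/bwd_gm_before_compiler.txt    # backward, pre-passes
    ./outputs/compiler/bwd_gm_after_compiler.txt     # backward, post-passes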