From ed90b66853ee4dd8f99306a1ed0d7690d28f8411 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 12 Nov 2025 15:21:55 -0800
Subject: [PATCH 1/3] Update (base update)

[ghstack-poisoned]
---
 .gitignore                                  |  3 ++
 scripts/dry_run.py                          |  3 ++
 .../deepseek_v3/parallelize.py              |  5 +-
 .../compiler_toolkit/graph_utils.py         | 50 ++++++++++++++-----
 .../compiler_toolkit/llama3/parallelize.py  |  5 +-
 torchtitan/train.py                         |  3 +-
 6 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/.gitignore b/.gitignore
index 45a8f5752a..415631ff9c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,3 +42,6 @@ Sessionx.vim
 
 # env files
 .env
+
+# Vibe coding
+.claude
diff --git a/scripts/dry_run.py b/scripts/dry_run.py
index 2552ca0d78..fa8e1b4c17 100644
--- a/scripts/dry_run.py
+++ b/scripts/dry_run.py
@@ -151,6 +151,9 @@ def __init__(self, job_config: JobConfig):
         logger.info("Configuration is ready for training execution.")
         logger.info("=" * 80)
 
+    def train(self):
+        return
+
 
 if __name__ == "__main__":
     main(DryRunTrainer)
diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index bc6859af61..20ad17f301 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -80,7 +80,9 @@ def parallelize_deepseekv3(
     compiler_passes = get_compiler_passes_from_config(job_config)
 
     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )
 
     # Create custom joint_graph_builder with deepseekv3-specific compilers
     deepseekv3_joint_graph_builder = functools.partial(
@@ -88,6 +90,7 @@ def parallelize_deepseekv3(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
+        dump_folder=job_config.job.dump_folder,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index aee089cad9..db998aa170 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+from pathlib import Path
 from typing import Callable, List, Optional
 
 import torch
@@ -21,8 +22,18 @@
 from torchtitan.tools.logging import logger
 
 
+def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
+    # TODO: make the dump rank configurable
+    if not dump_folder or torch.distributed.get_rank() != 0:
+        return
+
+    output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    output_path.write_text(gm.print_readable(print_output=False))
+
+
 def export_joint(
-    model, args, kwargs=None
+    model, args, kwargs=None, dump_folder: str | None = None
 ) -> tuple[JointWithDescriptors, TracingContext]:
     if kwargs is None:
         kwargs = {}
@@ -35,8 +46,10 @@ def export_joint(
         torch.fx.traceback.preserve_node_meta(),
     ):
         gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
-        logger.info("Dynamo gm:")
-        logger.info(gm.print_readable(print_output=False))
+        logger.debug("Dynamo gm:")
+        logger.debug(gm.print_readable(print_output=False))
+        _dump_gm(dump_folder, gm, "dynamo_gm")
+
     tracing_context = gm.meta["tracing_context"]
 
     with tracing(tracing_context):
@@ -68,6 +81,7 @@ def joint_graph_builder(
     fw_compiler: Optional[Callable] = None,
     bw_compiler: Optional[Callable] = None,
     joint_custom_pass: Optional[Callable] = None,
+    dump_folder: str | None = None,
 ):
     """
     Build a joint forward-backward graph for the model with optional custom compilers.
@@ -79,16 +93,17 @@
         fw_compiler: Optional custom forward compiler function
         bw_compiler: Optional custom backward compiler function
         joint_custom_pass: Optional custom pass to run on the joint graph
+        dump_folder: Optional folder to dump the graph to
     """
     assert isinstance(model_args, tuple)
-    for arg in model_args:
-        assert isinstance(arg, DTensor)
+    for idx, arg in enumerate(model_args):
+        assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"
 
     # get joint graph
     (
         joint_with_descriptors,
         tracing_context,
-    ) = export_joint(model, model_args, model_kwargs)
+    ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
 
     # Optional validation
     if joint_custom_pass is not None:
@@ -179,6 +194,7 @@ def compiler(
     gm: torch.fx.GraphModule,
     example_inputs,
     passes: List[Callable] = None,
+    dump_folder: str | None = None,
 ):
     """
     Compile a graph module by applying a sequence of compiler passes.
@@ -194,19 +210,23 @@
     if passes is None:
         passes = DEFAULT_COMPILER_PASSES
 
-    logger.info(f"{name} before compiler:")
-    logger.info(gm.print_readable(print_output=False))
+    logger.debug(f"{name} before compiler:")
+    logger.debug(gm.print_readable(print_output=False))
+    _dump_gm(dump_folder, gm, f"{name}_before_compiler")
 
     for pass_fn in passes:
         logger.info(f"Applying pass: {pass_fn.__name__}")
         gm = pass_fn(gm, example_inputs)
 
-    logger.info(f"{name} after compiler:")
-    logger.info(gm.print_readable(print_output=False))
+    logger.debug(f"{name} after compiler:")
+    logger.debug(gm.print_readable(print_output=False))
+    _dump_gm(dump_folder, gm, f"{name}_after_compiler")
 
     return gm
 
 
-def make_compiler_with_passes(passes: List[Callable] = None):
+def make_compiler_with_passes(
+    passes: List[Callable] = None, dump_folder: str | None = None
+):
     """
     Create forward and backward compilers with specified passes.
@@ -218,10 +238,14 @@ def make_compiler_with_passes(passes: List[Callable] = None):
     """
 
     def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler("fwd_gm", gm, example_inputs, passes=passes)
+        return compiler(
+            "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+        )
 
     def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler("bwd_gm", gm, example_inputs, passes=passes)
+        return compiler(
+            "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
+        )
 
     return fw_compiler, bw_compiler
 
diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index e3dca203e9..62def3ef00 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -67,7 +67,9 @@ def parallelize_llama(
     compiler_passes = get_compiler_passes_from_config(job_config)
 
     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
+    fw_compiler, bw_compiler = make_compiler_with_passes(
+        compiler_passes, dump_folder=job_config.job.dump_folder
+    )
 
     # Create custom joint_graph_builder with llama-specific compilers and validation
     llama_joint_graph_builder = functools.partial(
@@ -75,6 +77,7 @@ def parallelize_llama(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
+        dump_folder=job_config.job.dump_folder,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 18a876c4bb..5cfab998b2 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -735,7 +735,8 @@ def main(trainer_class: type[Trainer]) -> None:
         raise
     else:
         trainer.close()
-        torch.distributed.destroy_process_group()
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")
 

From c4308bff357b7693b18d303eabdc272d8d127b50 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 12 Nov 2025 15:21:55 -0800
Subject: [PATCH 2/3] Update

[ghstack-poisoned]
---
 .../experiments/compiler_toolkit/graph_utils.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index db998aa170..fe246c1af0 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -6,7 +6,7 @@
 
 import contextlib
 from pathlib import Path
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._dynamo.functional_export import dynamo_graph_capture_for_export
@@ -168,6 +168,18 @@ def __delattr__(self, name: str) -> None:
         else:
             super().__delattr__(name)
 
+    def state_dict(self, *args, **kwargs) -> Any:
+        return self.inner.state_dict(*args, **kwargs)
+
+    def load_state_dict(self, *args, **kwargs) -> Any:
+        return self.inner.load_state_dict(*args, **kwargs)
+
+    def named_parameters(self, *args, **kwargs) -> Any:
+        return self.inner.named_parameters(*args, **kwargs)
+
+    def parameters(self, *args, **kwargs) -> Any:
+        return self.inner.parameters(*args, **kwargs)
+
     def forward(self, *args, **kwargs):
         assert "forward" not in self._overrides, "forward cannot be overridden"
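(Illustrative sketch, not part of the patch series: PATCH 2/3 makes CompiledModule delegate state_dict, load_state_dict, named_parameters, and parameters to the wrapped `inner` module. The minimal stand-alone example below, using hypothetical `Wrapper` and `Tiny` classes, shows the effect of that delegation: parameter and state keys match the unwrapped model instead of gaining an extra `inner.` prefix.)

import torch


class Tiny(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for a module wrapper like CompiledModule."""

    def __init__(self, inner: torch.nn.Module) -> None:
        super().__init__()
        self.inner = inner

    # Delegations mirroring PATCH 2/3: route state and parameter access to
    # the wrapped module so its original key names are preserved.
    def state_dict(self, *args, **kwargs):
        return self.inner.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        return self.inner.load_state_dict(*args, **kwargs)

    def named_parameters(self, *args, **kwargs):
        return self.inner.named_parameters(*args, **kwargs)

    def parameters(self, *args, **kwargs):
        return self.inner.parameters(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)


wrapped = Wrapper(Tiny())
# Without the delegation, keys would carry an extra "inner." prefix
# ("inner.linear.weight"); with it, they match the unwrapped model.
assert "linear.weight" in wrapped.state_dict()
assert dict(wrapped.named_parameters()).keys() == dict(Tiny().named_parameters()).keys()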
From c6d47e8f7b9ba1d7597aa80554da0f112822473f Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Thu, 13 Nov 2025 00:18:22 -0800
Subject: [PATCH 3/3] Update (base update)

[ghstack-poisoned]
---
 .../deepseek_v3/parallelize.py              |  5 +-
 .../compiler_toolkit/graph_utils.py         | 51 +++++--------------
 .../compiler_toolkit/llama3/parallelize.py  |  5 +-
 3 files changed, 15 insertions(+), 46 deletions(-)

diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
index 20ad17f301..bc6859af61 100644
--- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py
@@ -80,9 +80,7 @@ def parallelize_deepseekv3(
     compiler_passes = get_compiler_passes_from_config(job_config)
 
     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(
-        compiler_passes, dump_folder=job_config.job.dump_folder
-    )
+    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
 
     # Create custom joint_graph_builder with deepseekv3-specific compilers
     deepseekv3_joint_graph_builder = functools.partial(
@@ -90,7 +88,6 @@ def parallelize_deepseekv3(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
-        dump_folder=job_config.job.dump_folder,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py
index 413ea066fb..aee089cad9 100644
--- a/torchtitan/experiments/compiler_toolkit/graph_utils.py
+++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
-from pathlib import Path
 from typing import Callable, List, Optional
 
 import torch
@@ -22,18 +21,8 @@
 from torchtitan.tools.logging import logger
 
 
-def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
-    # TODO: make the dump rank configurable
-    if not dump_folder or torch.distributed.get_rank() != 0:
-        return
-
-    output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
-    output_path.parent.mkdir(parents=True, exist_ok=True)
-    output_path.write_text(gm.print_readable(print_output=False))
-
-
 def export_joint(
-    model, args, kwargs=None, dump_folder: str | None = None
+    model, args, kwargs=None
 ) -> tuple[JointWithDescriptors, TracingContext]:
     if kwargs is None:
         kwargs = {}
@@ -46,10 +35,8 @@ def export_joint(
         torch.fx.traceback.preserve_node_meta(),
     ):
         gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
-        logger.debug("Dynamo gm:")
-        logger.debug(gm.print_readable(print_output=False))
-        _dump_gm(dump_folder, gm, "dynamo_gm")
-
+        logger.info("Dynamo gm:")
+        logger.info(gm.print_readable(print_output=False))
     tracing_context = gm.meta["tracing_context"]
 
     with tracing(tracing_context):
@@ -81,7 +68,6 @@ def joint_graph_builder(
     fw_compiler: Optional[Callable] = None,
     bw_compiler: Optional[Callable] = None,
     joint_custom_pass: Optional[Callable] = None,
-    dump_folder: str | None = None,
 ):
     """
     Build a joint forward-backward graph for the model with optional custom compilers.
@@ -93,17 +79,16 @@
         fw_compiler: Optional custom forward compiler function
         bw_compiler: Optional custom backward compiler function
         joint_custom_pass: Optional custom pass to run on the joint graph
-        dump_folder: Optional folder to dump the graph to
     """
     assert isinstance(model_args, tuple)
-    for idx, arg in enumerate(model_args):
-        assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}"
+    for arg in model_args:
+        assert isinstance(arg, DTensor)
 
     # get joint graph
     (
         joint_with_descriptors,
         tracing_context,
-    ) = export_joint(model, model_args, model_kwargs, dump_folder=dump_folder)
+    ) = export_joint(model, model_args, model_kwargs)
 
     # Optional validation
     if joint_custom_pass is not None:
@@ -194,7 +179,6 @@ def compiler(
     gm: torch.fx.GraphModule,
     example_inputs,
     passes: List[Callable] = None,
-    dump_folder: str | None = None,
 ):
     """
     Compile a graph module by applying a sequence of compiler passes.
@@ -206,28 +190,23 @@
         passes: List of compiler pass functions to apply. Each function should
                 take (gm, example_inputs) and return a transformed gm.
                 If None, uses DEFAULT_COMPILER_PASSES.
-        dump_folder: Optional folder to dump the graph to
     """
     if passes is None:
         passes = DEFAULT_COMPILER_PASSES
 
-    logger.debug(f"{name} before compiler:")
-    logger.debug(gm.print_readable(print_output=False))
-    _dump_gm(dump_folder, gm, f"{name}_before_compiler")
+    logger.info(f"{name} before compiler:")
+    logger.info(gm.print_readable(print_output=False))
 
     for pass_fn in passes:
         logger.info(f"Applying pass: {pass_fn.__name__}")
         gm = pass_fn(gm, example_inputs)
 
-    logger.debug(f"{name} after compiler:")
-    logger.debug(gm.print_readable(print_output=False))
-    _dump_gm(dump_folder, gm, f"{name}_after_compiler")
+    logger.info(f"{name} after compiler:")
+    logger.info(gm.print_readable(print_output=False))
 
     return gm
 
 
-def make_compiler_with_passes(
-    passes: List[Callable] = None, dump_folder: str | None = None
-):
+def make_compiler_with_passes(passes: List[Callable] = None):
     """
     Create forward and backward compilers with specified passes.
@@ -239,14 +218,10 @@ def make_compiler_with_passes(
     """
 
     def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler(
-            "fwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
-        )
+        return compiler("fwd_gm", gm, example_inputs, passes=passes)
 
     def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None:
-        return compiler(
-            "bwd_gm", gm, example_inputs, passes=passes, dump_folder=dump_folder
-        )
+        return compiler("bwd_gm", gm, example_inputs, passes=passes)
 
     return fw_compiler, bw_compiler
 
diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
index 62def3ef00..e3dca203e9 100644
--- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
+++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py
@@ -67,9 +67,7 @@ def parallelize_llama(
     compiler_passes = get_compiler_passes_from_config(job_config)
 
     # Create compilers with specified passes (defaults to no passes)
-    fw_compiler, bw_compiler = make_compiler_with_passes(
-        compiler_passes, dump_folder=job_config.job.dump_folder
-    )
+    fw_compiler, bw_compiler = make_compiler_with_passes(compiler_passes)
 
     # Create custom joint_graph_builder with llama-specific compilers and validation
     llama_joint_graph_builder = functools.partial(
@@ -77,7 +75,6 @@ def parallelize_llama(
         fw_compiler=fw_compiler,
         bw_compiler=bw_compiler,
         joint_custom_pass=validate_flex_attention_annotation,
-        dump_folder=job_config.job.dump_folder,
     )
 
     # TODO: CompiledModule should take sample input as well, so that we can
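(Illustrative sketch, not part of the patch series: the `_dump_gm` helper introduced in PATCH 1/3 writes each readable graph from rank 0 to `<dump_folder>/compiler/<name>.txt`, producing files such as `dynamo_gm.txt`, `fwd_gm_before_compiler.txt`, and `fwd_gm_after_compiler.txt`. The self-contained example below assumes a single-process gloo group and a trivially traced module; the `Tiny` class is hypothetical.)

import os
import tempfile
from pathlib import Path

import torch
import torch.distributed as dist


def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None:
    # Same logic as the helper added in PATCH 1/3: only rank 0 writes.
    if not dump_folder or dist.get_rank() != 0:
        return
    output_path = Path(dump_folder) / "compiler" / f"{name}.txt"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(gm.print_readable(print_output=False))


class Tiny(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


if __name__ == "__main__":
    # Single-process "gloo" group so dist.get_rank() is defined.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    gm = torch.fx.symbolic_trace(Tiny())
    with tempfile.TemporaryDirectory() as dump_folder:
        _dump_gm(dump_folder, gm, "dynamo_gm")
        # The dump lands at <dump_folder>/compiler/dynamo_gm.txt.
        print((Path(dump_folder) / "compiler" / "dynamo_gm.txt").read_text())

    dist.destroy_process_group()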