diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py deleted file mode 100644 index df997aabe9..0000000000 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from contextlib import contextmanager - -from torchtitan.config import JobConfig - - -@contextmanager -def disable_compile(job_config: JobConfig): - """Context manager to temporarily disable compilation.""" - original_value = job_config.compile.enable - job_config.compile.enable = False - try: - yield - finally: - job_config.compile.enable = original_value diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index a859415c1c..3e4dff5013 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -18,7 +18,6 @@ from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims from torchtitan.distributed.expert_parallel import ExpertParallel -from torchtitan.experiments.compiler_toolkit.common_utils import disable_compile from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, @@ -103,12 +102,13 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ) -> CompiledModule: + assert ( + not job_config.compile.enable + ), "compile.enable should be False in the compiler toolkit style workflow." annotate_model() - # Disable torch.compile over the model in the compiler toolkit style workflow - with disable_compile(job_config): - model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config) + model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config) # TODO: CompiledModule should take sample input as well, so that we can # compile ahead of time. diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index b8d00db39e..8ee937324c 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -15,7 +15,6 @@ from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims -from torchtitan.experiments.compiler_toolkit.common_utils import disable_compile from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, @@ -105,12 +104,13 @@ def parallelize_llama( parallel_dims: ParallelDims, job_config: JobConfig, ) -> CompiledModule: + assert ( + not job_config.compile.enable + ), "compile.enable should be False in the compiler toolkit style workflow." annotate_model() - # Disable torch.compile over the model in the compiler toolkit style workflow - with disable_compile(job_config): - model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) + model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) # TODO: CompiledModule should take sample input as well, so that we can # compile ahead of time.