diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 0793820ffd..c943978d30 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 9911ecdfd0..0764652a85 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -129,7 +129,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -506,7 +506,7 @@ def apply_moe_ep_tp(
     )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
@@ -577,6 +577,22 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
                 fullgraph=True,
             )
 
+            if ep_enabled:
+                compiled_fn = moe_module._run_experts_grouped_mm
+
+                def _run_experts_grouped_mm_dynamic(
+                    w1: torch.Tensor,
+                    w2: torch.Tensor,
+                    w3: torch.Tensor,
+                    x: torch.Tensor,
+                    num_tokens_per_expert: torch.Tensor,
+                ) -> torch.Tensor:
+                    # dynamic number of tokens in expert parallel
+                    torch._dynamo.mark_dynamic(x, 0)
+                    return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+                moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+
             # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
             # https://github.com/pytorch/pytorch/issues/166460
 
diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py
index 6b8dc3d5a6..adaa2ad3e8 100644
--- a/torchtitan/models/qwen3/infra/parallelize.py
+++ b/torchtitan/models/qwen3/infra/parallelize.py
@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, parallel_dims.ep_enabled)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel