Commit dfa2d7a

Add CUDA recipe (#158)
* Add CUDA recipe
* Write .ptd file
* Address comments
* Add a unit test
* Lint
1 parent 8335b49 commit dfa2d7a

File tree

4 files changed: +175 −0 lines changed

optimum/commands/export/executorch.py

Lines changed: 18 additions & 0 deletions
@@ -126,6 +126,20 @@ def parse_args_executorch(parser):
         required=False,
         help="Maximum sequence length for the model. If not specified, uses the model's default max_position_embeddings.",
     )
+    required_group.add_argument(
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        required=False,
+        help="Data type for model weights. Options: float32, float16, bfloat16. Default: float32. For quantization (int8/int4), use the --qlinear arguments.",
+    )
+    required_group.add_argument(
+        "--device",
+        type=str,
+        choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
+        required=False,
+        help="Device to run the model on. Options: cpu, cuda, cuda:<N>. Default: cpu.",
+    )


 class ExecuTorchExportCommand(BaseOptimumCLICommand):
@@ -159,6 +173,10 @@ def run(self):
             kwargs["qembedding_group_size"] = self.args.qembedding
         if self.args.max_seq_len:
            kwargs["max_seq_len"] = self.args.max_seq_len
+        if hasattr(self.args, "dtype") and self.args.dtype:
+            kwargs["dtype"] = self.args.dtype
+        if hasattr(self.args, "device") and self.args.device:
+            kwargs["device"] = self.args.device

         main_export(
             model_name_or_path=self.args.model,
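
The new flags are forwarded to main_export as plain kwargs, so the same export can be driven from Python. A minimal sketch, assuming main_export is importable from optimum.exporters.executorch (the import path and the example values are assumptions; the kwarg names match what run() forwards above):

from optimum.exporters.executorch import main_export  # assumed import path

main_export(
    model_name_or_path="mistralai/Voxtral-Mini-3B-2507",
    task="multimodal-text-to-text",
    recipe="cuda",
    output_dir="./voxtral_cuda",
    max_seq_len=1024,
    dtype="bfloat16",  # new --dtype flag; float32 if omitted
    device="cuda:0",   # new --device flag; cpu if omitted
)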

optimum/exporters/executorch/convert.py

Lines changed: 1 addition & 0 deletions
@@ -86,5 +86,6 @@ def export_to_executorch(
         logging.info(
             f"Saved exported program to {full_path} ({os.path.getsize(full_path) / (1024 * 1024):.2f} MB)"
         )
+        prog.write_tensor_data_to_file(output_dir)

     return executorch_progs
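
This new call writes tensor data to a separate .ptd file next to the .pte (the "Write .ptd file" item in the commit message). A quick way to inspect the effect after an export; the output path is illustrative and the exact .ptd filename is an assumption, not pinned down by the diff:

import os

output_dir = "./voxtral_cuda"  # illustrative; wherever --output_dir pointed
# Expect the program file (model.pte) plus tensor data (e.g. a .ptd file)
# emitted by write_tensor_data_to_file(output_dir).
for name in sorted(os.listdir(output_dir)):
    print(name, os.path.getsize(os.path.join(output_dir, name)), "bytes")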
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, Union
+
+import torch
+from tabulate import tabulate
+from torch.export import ExportedProgram
+from torch.nn.attention import SDPBackend
+
+from executorch.devtools.backend_debug import get_delegation_info
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchProgram,
+    to_edge_transform_and_lower,
+)
+from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass
+
+from ..integrations import (
+    CausalLMExportableModule,
+    MaskedLMExportableModule,
+    MultiModalTextToTextExportableModule,
+    Seq2SeqLMExportableModule,
+)
+from ..recipe_registry import register_recipe
+
+
+aten = torch.ops.aten
+
+
+@register_recipe("cuda")
+def export_to_executorch_with_cuda(
+    model: Union[
+        CausalLMExportableModule,
+        MaskedLMExportableModule,
+        Seq2SeqLMExportableModule,
+        MultiModalTextToTextExportableModule,
+    ],
+    **kwargs,
+):
+    """
+    Export a PyTorch model to ExecuTorch with delegation to the CUDA backend.
+    This function also writes metadata required by the ExecuTorch runtime to the .pte file.
+
+    Args:
+        model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultiModalTextToTextExportableModule]):
+            The PyTorch model to be exported to ExecuTorch.
+        **kwargs:
+            Additional keyword arguments for recipe-specific configurations, e.g. export using different example inputs, or different compile/backend configs.
+
+    Returns:
+        Dict[str, ExecutorchProgram]:
+            A map of exported and optimized programs for ExecuTorch.
+            For encoder-decoder models or multimodal models, it may generate multiple programs.
+    """
+    # Import here to avoid version conflicts.
+    from torch._inductor.decomposition import conv1d_to_conv2d
+
+    from executorch.backends.cuda.cuda_backend import CudaBackend
+    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
+
+    def _lower_to_executorch(
+        exported_programs: Dict[str, ExportedProgram],
+        metadata=None,
+    ) -> Dict[str, ExecutorchProgram]:
+        logging.debug(f"\nExported program: {exported_programs}")
+
+        # If there is just one exported program, its method name in the .pte should be "forward".
+        if len(exported_programs) == 1:
+            exported_programs = {"forward": next(iter(exported_programs.values()))}
+
+        # CUDA backend compile spec with method name.
+        partitioners = {
+            key: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(key)])]
+            for key in exported_programs.keys()
+        }
+        # Add decompositions for triton to generate kernels.
+        for key, ep in exported_programs.items():
+            exported_programs[key] = ep.run_decompositions(
+                {
+                    aten.conv1d.default: conv1d_to_conv2d,
+                }
+            )
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
+            et_prog = to_edge_transform_and_lower(
+                exported_programs,
+                partitioner=partitioners,
+                compile_config=EdgeCompileConfig(
+                    _check_ir_validity=False,
+                    _skip_dim_order=True,
+                ),
+                constant_methods=metadata,
+                transform_passes=[RemovePaddingIdxEmbeddingPass()],
+            )
+        et_prog = et_prog.to_executorch()
+        pte_name = "model"
+        for method in et_prog.methods:
+            logging.debug(f"---------------------- Method: {method} ----------------------")
+            logging.debug(f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}")
+            delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module)
+            logging.debug(f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}")
+            logging.debug(
+                f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
+            )
+        return {pte_name: et_prog}
+
+    # Decompose SDPA since we don't have a flash attention kernel for it yet.
+    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+        exported_progs = model.export()
+
+    if (
+        model.config._attn_implementation == "custom_sdpa"
+        or model.config._attn_implementation == "custom_sdpa_ring_kv_cache"
+    ):
+        raise NotImplementedError(
+            "Custom SDPA implementation is not supported for CUDA yet. Please use 'flash_attention' instead."
+        )
+
+    return _lower_to_executorch(exported_progs, model.metadata)
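
For readers unfamiliar with the lowering flow, the core pattern of this recipe (a per-method CudaPartitioner built from a method-name compile spec, then to_edge_transform_and_lower followed by to_executorch) can be exercised on a toy module. A minimal sketch, assuming a CUDA-capable environment; the toy model, shapes, and file names are illustrative, and only API calls already shown above are used:

import torch
from torch.export import export

from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower


class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.fc(x))


ep = export(TinyMLP().eval(), (torch.randn(1, 16),))

# One partitioner per method, keyed by the method name that ends up in the .pte.
partitioners = {
    "forward": [CudaPartitioner([CudaBackend.generate_method_name_compile_spec("forward")])]
}

et_prog = to_edge_transform_and_lower(
    {"forward": ep},
    partitioner=partitioners,
    compile_config=EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True),
).to_executorch()

with open("tiny_model.pte", "wb") as f:
    f.write(et_prog.buffer)
et_prog.write_tensor_data_to_file(".")  # same call convert.py now makes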

tests/models/test_modeling_voxtral.py

Lines changed: 27 additions & 0 deletions
@@ -16,7 +16,9 @@
 import gc
 import logging
 import os
+import subprocess
 import sys
+import tempfile
 import unittest

 import pytest
@@ -324,3 +326,28 @@ def test_voxtral_audio_text_to_text_generation_with_custom_sdpa_kv_cache_8da4w_8
         self.assertTrue(
             check_multimodal_output_quality(model_id, generated_tokens, conversation, max_perplexity_threshold=5)
         )
+
+    @slow
+    @pytest.mark.run_slow
+    @pytest.mark.skipif(is_linux_ci, reason="OOM")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA backend required")
+    def test_voxtral_export_to_executorch_cuda_recipe(self):
+        model_id = "mistralai/Voxtral-Mini-3B-2507"
+        task = "multimodal-text-to-text"
+        recipe = "cuda"
+        output_subdir = "executorch"
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            output_dir = os.path.join(tempdir, output_subdir)
+            cmd = (
+                "optimum-cli export executorch "
+                f"--model {model_id} "
+                f"--task {task} "
+                f"--recipe {recipe} "
+                "--dtype bfloat16 "
+                "--device cuda:0 "
+                "--max_seq_len 1024 "
+                f"--output_dir {output_dir}"
+            )
+            subprocess.run(cmd, shell=True, check=True)
+            self.assertTrue(os.path.exists(os.path.join(output_dir, "model.pte")))
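
Since the command string here goes through shell=True, an equivalent argument-list form (a sketch of the same invocation, not part of the commit) avoids shell quoting pitfalls:

cmd = [
    "optimum-cli", "export", "executorch",
    "--model", model_id,
    "--task", task,
    "--recipe", recipe,
    "--dtype", "bfloat16",
    "--device", "cuda:0",
    "--max_seq_len", "1024",
    "--output_dir", output_dir,
]
subprocess.run(cmd, check=True)

Given that the export also emits tensor data (see the write_tensor_data_to_file call in convert.py), the test could additionally assert that a .ptd file landed next to model.pte; the exact .ptd filename is not pinned down here.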
