# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Union

import torch
from tabulate import tabulate
from torch.export import ExportedProgram
from torch.nn.attention import SDPBackend

from executorch.devtools.backend_debug import get_delegation_info
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchProgram,
    to_edge_transform_and_lower,
)
from optimum.executorch.passes.remove_padding_idx_embedding_pass import RemovePaddingIdxEmbeddingPass

from ..integrations import (
    CausalLMExportableModule,
    MaskedLMExportableModule,
    MultiModalTextToTextExportableModule,
    Seq2SeqLMExportableModule,
)
from ..recipe_registry import register_recipe


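# Shorthand alias for the ATen operator namespace, used below to key the decomposition table.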
aten = torch.ops.aten


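# Registering under "cuda" lets callers select this recipe by name (e.g. via a
# `recipe="cuda"` argument at the export entry point).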
@register_recipe("cuda")
def export_to_executorch_with_cuda(
    model: Union[
        CausalLMExportableModule,
        MaskedLMExportableModule,
        Seq2SeqLMExportableModule,
        MultiModalTextToTextExportableModule,
    ],
    **kwargs,
):
    """
    Export a PyTorch model to ExecuTorch with delegation to the CUDA backend.
    This function also writes the metadata required by the ExecuTorch runtime to the .pte file.

    Args:
        model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule, MultiModalTextToTextExportableModule]):
            The PyTorch model to be exported to ExecuTorch.
        **kwargs:
            Additional keyword arguments for recipe-specific configurations, e.g. exporting with different example inputs or different compile/backend configs.

    Returns:
        Dict[str, ExecutorchProgram]:
            A map of exported and optimized programs for ExecuTorch.
            For encoder-decoder or multimodal models, it may contain multiple programs.
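
    Example:
        A minimal usage sketch: selecting this recipe through the high-level
        `optimum.executorch` API (the model id is illustrative; adjust the entry
        point to your installed version):

        ```python
        from optimum.executorch import ExecuTorchModelForCausalLM

        # Exports the model with the "cuda" recipe registered by this module.
        model = ExecuTorchModelForCausalLM.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M",
            recipe="cuda",
        )
        ```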
    """
    # Import here to avoid version conflicts.
    from torch._inductor.decomposition import conv1d_to_conv2d

    from executorch.backends.cuda.cuda_backend import CudaBackend
    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner

    def _lower_to_executorch(
        exported_programs: Dict[str, ExportedProgram],
        metadata=None,
    ) -> Dict[str, ExecutorchProgram]:
        logging.debug(f"\nExported program: {exported_programs}")

        # If there is only one exported program, its method name in the .pte file should be "forward".
        if len(exported_programs) == 1:
            exported_programs = {"forward": next(iter(exported_programs.values()))}

        # Build one CUDA partitioner per method, each with a compile spec carrying its method name.
        partitioners = {
            key: [CudaPartitioner([CudaBackend.generate_method_name_compile_spec(key)])]
            for key in exported_programs.keys()
        }
        # Register decompositions so Triton can generate kernels (e.g. conv1d lowered to conv2d).
        for key, ep in exported_programs.items():
            exported_programs[key] = ep.run_decompositions(
                {
                    aten.conv1d.default: conv1d_to_conv2d,
                }
            )
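        # Pin SDPA to the MATH backend during lowering too, so scaled_dot_product_attention
        # stays decomposed into primitive ops (no flash-attention kernel is available here yet).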
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]):
            et_prog = to_edge_transform_and_lower(
                exported_programs,
                partitioner=partitioners,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=False,
                    _skip_dim_order=True,
                ),
                constant_methods=metadata,
                transform_passes=[RemovePaddingIdxEmbeddingPass()],
            )
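        # Serialize the lowered edge program into its final executable ExecuTorch form.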
        et_prog = et_prog.to_executorch()
        pte_name = "model"
        for method in et_prog.methods:
            logging.debug(f"---------------------- Method: {method} ----------------------")
            logging.debug(f"\nExecuTorch program for {pte_name}.pte: {et_prog.exported_program(method).graph_module}")
            delegation_info = get_delegation_info(et_prog.exported_program(method).graph_module)
            logging.debug(f"\nDelegation info summary for {pte_name}.pte: {delegation_info.get_summary()}")
            logging.debug(
                f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
            )
        return {pte_name: et_prog}

    # Fail fast on unsupported attention implementations before doing the export work.
    if model.config._attn_implementation in ("custom_sdpa", "custom_sdpa_ring_kv_cache"):
        raise NotImplementedError(
            "Custom SDPA implementation is not supported for CUDA yet. Please use 'flash_attention' instead."
        )

    # Decompose SDPA, since we don't have a flash attention kernel for it yet.
    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        exported_progs = model.export()

    return _lower_to_executorch(exported_progs, model.metadata)