install_dev.py (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ def install_dep_from_source():
"-m",
"pip",
"install",
"git+https://github.com/huggingface/transformers@91393fe4cc3266a05bc0d129e34ff5f761bb46e2#egg=transformers", # 4.56.1
"git+https://github.com/huggingface/transformers@a5c903f877fda21e739027eed133e03162eb7712#egg=transformers", # Current tip of main https://github.com/huggingface/transformers/pull/42260.
]
)
subprocess.check_call(
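Reviewer note: for anyone reproducing this outside install_dev.py, the new pin can be installed on its own with the same pip invocation; a minimal sketch, assuming the same interpreter:

```python
# Sketch only: mirrors the updated subprocess call with the new transformers pin.
import subprocess
import sys

subprocess.check_call(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/huggingface/transformers@a5c903f877fda21e739027eed133e03162eb7712#egg=transformers",
    ]
)
```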
optimum/commands/export/executorch.py (5 changes: 3 additions & 2 deletions)
@@ -17,7 +17,8 @@
from pathlib import Path
from typing import TYPE_CHECKING

- from ...exporters import TasksManager
+ from transformers.pipelines import get_supported_tasks

from ..base import BaseOptimumCLICommand, CommandInfo


@@ -46,7 +47,7 @@ def parse_args_executorch(parser):
default="text-generation",
help=(
"The task to export the model for. Available tasks depend on the model, but are among:"
f" {str(TasksManager.get_all_tasks())}."
f" {str(get_supported_tasks())}."
),
)
required_group.add_argument(
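Reviewer note: get_supported_tasks() is the transformers.pipelines helper imported above; it returns the sorted list of pipeline task names, so the task option's help text now advertises transformers tasks rather than the optimum TasksManager registry. A minimal sketch of the resulting help string:

```python
# Sketch: reproduce the help text the argparse option builds after this change.
from transformers.pipelines import get_supported_tasks

help_text = (
    "The task to export the model for. Available tasks depend on the model, but are among:"
    f" {str(get_supported_tasks())}."
)
print(help_text)  # lists task names such as "text-generation"
```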
optimum/executorch/attentions/custom_sdpa.py (12 changes: 8 additions & 4 deletions)
@@ -24,6 +24,7 @@ def custom_sdpa_with_start_pos_forward(
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Union[torch.Tensor, "BlockMask"], # noqa
+ position_ids: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -56,10 +57,10 @@ def custom_sdpa_with_start_pos_forward(
# Calculate the input pos from attention mask.
# Branch out for float vs bool mask
# assert attention_mask.dim() == 2, f"attention_mask must be a 2D matrix."
- attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1])
- first_row_mask = attention_mask[0, :]
- # [0, 0, 0, 0, -inf, -inf, -inf, -inf], start_pos = 3
- start_pos = torch.argmin(first_row_mask.to(torch.long)).item() - 1
+ assert (
+     position_ids is not None
+ ), "position_ids must be provided to find start position for causal attention"
+ start_pos = position_ids[0][0].item()
else:
start_pos = 0

@@ -95,6 +96,7 @@ def _custom_sdpa_for_ring_kv_cache(
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Union[torch.Tensor, "BlockMask"], # noqa
+ position_ids: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
softcap: Optional[float] = None,
head_mask: Optional[torch.Tensor] = None,
@@ -122,6 +124,7 @@ def _custom_sdpa_for_ring_kv_cache(
key,
value,
attention_mask,
+ position_ids,
scaling,
softcap,
head_mask,
@@ -134,6 +137,7 @@ def _custom_sdpa_for_ring_kv_cache(
key,
value,
attention_mask,
+ position_ids,
scaling,
softcap,
head_mask,
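Reviewer note: the start position is now taken directly from position_ids instead of being recovered by scanning the first row of the float attention mask for its first -inf entry, which is why the new position_ids argument is threaded through both custom SDPA variants. A toy sketch of the two derivations (illustration only, not repo code):

```python
# Sketch: old vs new derivation of start_pos when the current token is at position 3.
import torch

# Old path: locate the first masked (-inf) column in row 0 of the float mask.
mask_row = torch.tensor([0.0, 0.0, 0.0, 0.0, float("-inf"), float("-inf")])
old_start_pos = torch.argmin(mask_row.to(torch.long)).item() - 1  # -> 3

# New path: read the first position id passed to the attention call.
position_ids = torch.tensor([[3, 4, 5]])  # decoding after a 3-token prefill
new_start_pos = position_ids[0][0].item()  # -> 3

print(old_start_pos, new_start_pos)
```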
optimum/executorch/modeling.py (15 changes: 5 additions & 10 deletions)
@@ -34,9 +34,9 @@
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
PreTrainedTokenizer,
- add_start_docstrings,
)
from transformers.configuration_utils import PretrainedConfig
+ from transformers.pipelines import get_task
from transformers.processing_utils import ProcessorMixin
from transformers.utils import is_offline_mode

@@ -46,13 +46,11 @@
)
from executorch.kernels import quantized # noqa

- from ..exporters import TasksManager
from ..exporters.executorch import main_export
from ..exporters.executorch.utils import (
process_conversation_inputs,
verify_eos_tokens_in_pretrained_tokenizer,
)
- from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
from ..utils.file_utils import find_files_matching_pattern
from .stats import Stats

@@ -63,7 +61,7 @@
logger = logging.getLogger(__name__)


- class ExecuTorchModelBase(OptimizedModel, ABC):
+ class ExecuTorchModelBase(ABC):
"""
ExecuTorch model for inference using the ExecuTorch Runtime.

@@ -99,8 +97,6 @@ def __init__(
models: Dict[str, "ExecuTorchModule"],
config: "PretrainedConfig",
):
- super().__init__(model=None, config=config)

if self.__class__.auto_model_class is None:
raise ValueError(
f"Class {self.__class__.__name__} must set auto_model_class. "
@@ -268,6 +264,7 @@ def _export(
cls,
model_id: str,
recipe: str,
+ task: Optional[str] = None,
config: Optional[PretrainedConfig] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
@@ -278,9 +275,8 @@
local_files_only: bool = False,
**kwargs,
) -> Dict[str, "ExecuTorchModule"]:
task = kwargs.pop("task", None)
inferred_task = TasksManager.infer_task_from_model(cls.auto_model_class) if not task else task
logging.info(f"Inferred task from model class: {inferred_task}")
inferred_task = get_task(model_id) if not task else task
logging.info(f"Using task: {inferred_task}")

save_dir = TemporaryDirectory(prefix="executorch_export_")
save_dir_path = Path(save_dir.name)
@@ -316,7 +312,6 @@ def _save_pretrained(self, save_directory):
raise NotImplementedError

@classmethod
- @add_start_docstrings(FROM_PRETRAINED_START_DOCSTRING)
def from_pretrained(
cls,
model_id: Union[str, Path],
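Reviewer note: task inference now uses transformers.pipelines.get_task, which resolves the task from the model id's pipeline tag on the Hugging Face Hub, replacing TasksManager.infer_task_from_model; the new explicit task parameter on _export lets callers skip that lookup. A rough sketch of the new path, with a placeholder model id:

```python
# Sketch of the task-inference fallback in _export (placeholder model id, needs Hub access).
from transformers.pipelines import get_task

model_id = "someuser/some-causal-lm"  # hypothetical Hub repo id
task = None  # as when no task is passed explicitly

inferred_task = get_task(model_id) if not task else task
print(f"Using task: {inferred_task}")  # e.g. "text-generation"
```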
optimum/exporters/executorch/convert.py (5 changes: 2 additions & 3 deletions)
@@ -19,8 +19,7 @@
from pathlib import Path
from typing import Union

- from transformers.integrations.executorch import sdpa_mask_without_vmap
- from transformers.masking_utils import AttentionMaskInterface
+ from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward
@@ -29,7 +28,7 @@


AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
AttentionMaskInterface.register("custom_sdpa", sdpa_mask_without_vmap)
AttentionMaskInterface.register("custom_sdpa", ALL_MASK_ATTENTION_FUNCTIONS["sdpa"])


def export_to_executorch(
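Reviewer note: the mask paired with the "custom_sdpa" attention kernel is now the stock "sdpa" mask builder fetched from ALL_MASK_ATTENTION_FUNCTIONS, replacing the removed sdpa_mask_without_vmap helper. A small sketch of what the registration pair above accomplishes (illustrative):

```python
# Sketch: pair a custom attention kernel with the stock sdpa mask builder under one name,
# so a config using _attn_implementation = "custom_sdpa" can resolve both by that key.
from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

from optimum.executorch.attentions.custom_sdpa import custom_sdpa_with_start_pos_forward

AttentionInterface.register("custom_sdpa", custom_sdpa_with_start_pos_forward)
AttentionMaskInterface.register("custom_sdpa", ALL_MASK_ATTENTION_FUNCTIONS["sdpa"])

# After registration, a lookup by the new key should return the same mask function
# that the "sdpa" entry maps to.
print(ALL_MASK_ATTENTION_FUNCTIONS["custom_sdpa"])
```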
optimum/exporters/executorch/integrations.py (8 changes: 3 additions & 5 deletions)
@@ -31,9 +31,8 @@
)
from transformers.integrations.executorch import (
TorchExportableModuleForDecoderOnlyLM,
- sdpa_mask_without_vmap,
)
- from transformers.masking_utils import AttentionMaskInterface
+ from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
@@ -269,7 +268,7 @@ def _register_custom_attention(self, exportable_module: torch.nn.Module):
if self.use_custom_sdpa:
if self.use_custom_kv_cache:
AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", ALL_MASK_ATTENTION_FUNCTIONS["sdpa"])
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
@@ -471,15 +470,14 @@ def _prepare_export_inputs(self):
return example_input_ids, example_cache_position, dynamic_shapes, strict

def _register_custom_attention(self, exportable_module: torch.nn.Module):
- from transformers.integrations.executorch import sdpa_mask_without_vmap
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface

if self.use_custom_sdpa:
if self.use_custom_kv_cache:
_custom_sdpa_for_ring_kv_cache = get_custom_sdpa_for_ring_kv_cache(exportable_module)
AttentionInterface.register("custom_sdpa_ring_kv_cache", _custom_sdpa_for_ring_kv_cache)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", sdpa_mask_without_vmap)
AttentionMaskInterface.register("custom_sdpa_ring_kv_cache", ALL_MASK_ATTENTION_FUNCTIONS["sdpa"])
# Manually set the attention implementation to custom_sdpa_ring_kv_cache
# This handles both regular sdpa and one for sliding window/local attention
exportable_module.model.model.config._attn_implementation = "custom_sdpa_ring_kv_cache"
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ classifiers = [
]

dependencies = [
"optimum~=1.24",
"optimum~=2.0.0",
"executorch>=1.0.0",
"transformers==4.56.1",
"pytorch-tokenizers>=1.0.1",
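Reviewer note: `~=2.0.0` is a PEP 440 compatible-release specifier, equivalent to `>=2.0.0,<2.1.0`, so this now restricts optimum to the 2.0.x series, whereas the previous `~=1.24` accepted any 1.x release at or above 1.24.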