Commit ca4828c

Wav2vec2 onboarding (#571)
Added support for the "facebook/wav2vec2-base-960h" model via the AutoModelForCTC class.

Signed-off-by: Tanisha Chawada <tchawada@qti.qualcomm.com>
1 parent c3aa753 commit ca4828c

File tree: 7 files changed, +541 -2 lines changed


QEfficient/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -8,15 +8,16 @@
 import os
 import warnings
 
+import QEfficient.utils.model_registery  # noqa: F401
 from QEfficient.utils import custom_format_warning
+from QEfficient.utils.logging_utils import logger
 
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 # Placeholder for all non-transformer models registered in QEfficient
-import QEfficient.utils.model_registery  # noqa: F401
-from QEfficient.utils.logging_utils import logger
+
 
 # custom warning for the better logging experience
 warnings.formatwarning = custom_format_warning
@@ -43,6 +44,7 @@ def check_qaic_sdk():
     from QEfficient.base import (
         QEFFAutoModel,
         QEFFAutoModelForCausalLM,
+        QEFFAutoModelForCTC,
         QEFFAutoModelForImageTextToText,
         QEFFAutoModelForSpeechSeq2Seq,
         QEFFCommonLoader,
@@ -63,6 +65,7 @@ def check_qaic_sdk():
         "cloud_ai_100_exec_kv",
         "QEFFAutoModel",
         "QEFFAutoModelForCausalLM",
+        "QEFFAutoModelForCTC",
         "QEffAutoPeftModelForCausalLM",
         "QEFFAutoModelForImageTextToText",
         "QEFFAutoModelForSpeechSeq2Seq",

QEfficient/base/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from QEfficient.transformers.models.modeling_auto import (  # noqa: F401
     QEFFAutoModel,
     QEFFAutoModelForCausalLM,
+    QEFFAutoModelForCTC,
     QEFFAutoModelForImageTextToText,
     QEFFAutoModelForSpeechSeq2Seq,
 )
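Taken together, the two wiring changes above expose the new auto class both at the package root and via `QEfficient.base`, resolving to the same object. A minimal sanity check, assuming the Cloud AI 100 SDK is installed (the top-level re-exports are gated on SDK detection); this snippet is illustrative and not part of the diff:

```python
# Hedged sketch: verifies the import wiring added by this commit
from QEfficient import QEFFAutoModelForCTC
from QEfficient.base import QEFFAutoModelForCTC as BaseQEFFAutoModelForCTC

assert QEFFAutoModelForCTC is BaseQEFFAutoModelForCTC  # both import paths point at the same class
```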

QEfficient/transformers/models/modeling_auto.py

Lines changed: 285 additions & 0 deletions
@@ -16,6 +16,7 @@
 from transformers import (
     AutoModel,
     AutoModelForCausalLM,
+    AutoModelForCTC,
     AutoModelForImageTextToText,
     AutoModelForSpeechSeq2Seq,
     PreTrainedTokenizer,
@@ -3077,3 +3078,287 @@ def generate(
             generated_ids=generated_ids,
             perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
         )
+
+
+class QEFFAutoModelForCTC(QEFFTransformersBase):
+    """
+    The QEFFAutoModelForCTC class is designed for transformer models with a Connectionist Temporal Classification (CTC) speech-to-text head,
+    including Wav2Vec2 and other encoder-only speech models optimized for alignment-free transcription.
+    Although it is possible to initialize the class directly, we highly recommend using the ``from_pretrained`` method for initialization.
+
+    ``Mandatory`` Args:
+        :model (nn.Module): PyTorch model
+
+    .. code-block:: python
+
+        import torchaudio
+        from QEfficient import QEFFAutoModelForCTC
+        from transformers import AutoProcessor
+
+        # Initialize the model using from_pretrained, similar to transformers.AutoModelForCTC.
+        model = QEFFAutoModelForCTC.from_pretrained(model_name)
+
+        # Now you can directly compile the model for Cloud AI 100
+        model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
+
+        # Prepare input
+        processor = AutoProcessor.from_pretrained(model_name)
+        input_audio, sample_rate = [...]  # audio data loaded via an external audio package, such as librosa or soundfile
+
+        # Resample the input_audio if necessary
+        if input_audio.shape[0] > 1:
+            input_audio = input_audio.mean(dim=0)
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+            input_audio = resampler(input_audio)
+
+        # You can now execute the model
+        out = model.generate(processor, inputs=input_audio)
+    """
+
+    _hf_auto_class = AutoModelForCTC
+    _pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
+    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+
+    def __init__(self, model: nn.Module, **kwargs):
+        super().__init__(model, **kwargs)
+        self.model.base_model.config.use_cache = True
+
+        self.hash_params["qeff_auto_class"] = self.__class__.__name__
+
+    @classmethod
+    @with_replaced_quantizers
+    def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **kwargs):
+        """
+        This method serves as the easiest entry point into using QEfficient. The interface is designed to be similar to transformers.AutoModelForCTC.
+        Once the model is initialized, you can use other methods such as export, compile, and generate on the same object.
+
+        Args:
+            pretrained_model_name_or_path (str): The name or path of the pre-trained model.
+
+        .. code-block:: python
+
+            import torchaudio
+            from QEfficient import QEFFAutoModelForCTC
+            from transformers import AutoProcessor
+
+            # Initialize the model using from_pretrained, similar to transformers.AutoModelForCTC.
+            model = QEFFAutoModelForCTC.from_pretrained(model_name)
+
+            # Now you can directly compile the model for Cloud AI 100
+            model.compile(num_cores=16)  # Considering you have a Cloud AI 100 SKU
+
+            # Prepare input
+            processor = AutoProcessor.from_pretrained(model_name)
+            input_audio, sample_rate = [...]  # audio data loaded via an external audio package, such as librosa or soundfile
+
+            # Resample the input_audio if necessary
+            if input_audio.shape[0] > 1:
+                input_audio = input_audio.mean(dim=0)
+            if sample_rate != 16000:
+                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+                input_audio = resampler(input_audio)
+
+            # You can now execute the model
+            out = model.generate(processor, inputs=input_audio)
+        """
+        if kwargs.get("attn_implementation", None) not in {None, "eager"}:
+            logger.warning('Updating attn_implementation="eager"')
+
+        if kwargs.get("low_cpu_mem_usage", None):
+            logger.warning("Updating low_cpu_mem_usage=False")
+
+        kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+
+        model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+
+        # This supports models that should be classified into a different auto class but that transformers loads via this class
+        kv_offload = kwargs.pop("kv_offload", None)
+        if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
+            return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
+                model, kv_offload=kv_offload, **kwargs
+            )
+
+        return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs)
+
+    @property
+    def get_model_config(self) -> dict:
+        return self.model.config.__dict__
+
+    def export(self, export_dir: Optional[str] = None) -> str:
+        """
+        Exports the model to ``ONNX`` format using ``torch.onnx.export``.
+
+        ``Optional`` Args:
+            :export_dir (str, optional): The directory path to store the ONNX graph.
+
+        Returns:
+            :str: Path of the generated ``ONNX`` graph.
+        """
+        bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        seq_len = constants.WAV2VEC2_MAX_SEQ_LEN
+
+        example_inputs = {
+            "input_values": torch.zeros((bs, seq_len), dtype=torch.float32),
+        }
+
+        dynamic_axes = {"input_values": {0: "batch_size", 1: "seq_len"}}
+
+        output_names = ["logits"]
+
+        return self._export(
+            example_inputs,
+            output_names,
+            dynamic_axes,
+            export_dir=export_dir,
+        )
+
+    def compile(
+        self,
+        onnx_path: Optional[str] = None,
+        compile_dir: Optional[str] = None,
+        *,
+        seq_len: Union[int, List[int]] = 480000,
+        batch_size: int = 1,
+        num_devices: int = 1,
+        num_cores: int = 16,  # FIXME: Make this a mandatory arg
+        mxfp6_matmul: bool = False,
+        **compiler_options,
+    ) -> str:
+        """
+        This method compiles the exported ``ONNX`` model using the Cloud AI 100 Platform SDK compiler binary found at ``/opt/qti-aic/exec/qaic-exec`` and generates a ``qpc`` package.
+        If the model has not been exported yet, this method will handle the export process.
+        You can pass any other arguments that `qaic-exec` accepts as extra kwargs.
+
+        ``Optional`` Args:
+            :onnx_path (str, optional): Path to a pre-exported ONNX model.
+            :compile_dir (str, optional): Path for saving the generated qpc.
+            :seq_len (Union[int, List[int]]): Number of input audio samples per specialization; the padded input must not exceed ``seq_len``. ``Defaults to 480000``.
+            :batch_size (int, optional): Batch size. ``Defaults to 1``.
+            :num_devices (int): Number of devices the model needs to be compiled for. ``Defaults to 1``.
+            :num_cores (int): Number of cores used to compile the model.
+            :mxfp6_matmul (bool, optional): Whether to use ``mxfp6`` compression for weights. ``Defaults to False``.
+            :compiler_options (dict, optional): Additional compiler options.
+
+                For the QAIC compiler, extra arguments for qaic-exec can be passed:
+                    :aic_enable_depth_first (bool, optional): Enables DFS with default memory size. ``Defaults to False``.
+                    :allow_mxint8_mdp_io (bool, optional): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False``.
+
+                    Params are converted to flags as below:
+
+                    - aic_hw_version=ai100 -> -aic-hw-version=ai100
+                    - aic_hw_version=ai200 -> -aic-hw-version=ai200
+
+                For the QNN compiler, the following arguments can be passed:
+                    :enable_qnn (bool): Enables QNN compilation.
+                    :qnn_config (str): Path of the QNN config parameters file. Any extra parameters for QNN compilation can be passed via this file.
+
+        Returns:
+            :str: Path of the compiled ``qpc`` package.
+        """
+        specializations = [
+            {"batch_size": batch_size, "seq_len": sl} for sl in (seq_len if isinstance(seq_len, list) else [seq_len])
+        ]
+
+        return self._compile(
+            onnx_path=onnx_path,
+            compile_dir=compile_dir,
+            compile_only=True,
+            specializations=specializations,
+            convert_to_fp16=True,
+            mxfp6_matmul=mxfp6_matmul,
+            mdp_ts_num_devices=num_devices,
+            aic_num_cores=num_cores,
+            **compiler_options,
+        )
+
+    def generate(
+        self,
+        processor,
+        inputs: torch.Tensor,
+        device_ids: List[int] = None,
+        runtime_ai100: bool = True,
+    ) -> Union[torch.Tensor, np.ndarray]:
+        """
+        This method generates output by executing the PyTorch runtime or the compiled ``qpc`` on ``Cloud AI 100`` hardware cards.
+
+        ``Mandatory`` Args:
+            :processor (AutoProcessor): The processor to use for encoding the waveform.
+            :inputs (Union[torch.Tensor, np.ndarray]): Inputs to run the execution.
+        ``Optional`` Args:
+            :device_ids (List[int]): IDs of devices for running the qpc; pass [0] for a normal model and, e.g., [0, 1, 2, 3] for a tensor-sliced model.
+            :runtime_ai100 (bool, optional): ``AI_100`` and ``PyTorch`` runtimes are supported as of now. Defaults to ``True`` for the ``AI_100`` runtime.
+        Returns:
+            :dict: Output from the ``AI_100`` or ``PyTorch`` runtime.
+        """
+        # AI_100 runtime
+        if runtime_ai100:
+            if not isinstance(self.qpc_path, Path):
+                raise TypeError("Please run compile API first!")
+
+            return self.cloud_ai_100_feature_generate(processor, inputs=inputs, device_ids=device_ids)
+        # PyTorch runtime
+        else:
+            return self.pytorch_feature_generate(processor, model=self.model, inputs=inputs)
+
+    def cloud_ai_100_feature_generate(
+        self,
+        processor,
+        inputs: torch.Tensor,
+        device_ids: List[int] = [0],
+    ) -> np.ndarray:
+        """
+        Generates transcriptions for the given audio input using the AI 100 runtime.
+
+        ``Mandatory`` Args:
+            :processor (AutoProcessor): The processor to use for encoding the waveform.
+            :inputs (Union[torch.Tensor, np.ndarray]): Inputs to run the execution.
+        ``Optional`` Args:
+            :device_ids (List[int], optional): A list of device IDs to use for the session. Defaults to [0].
+        """
+        if self.qpc_session is None:
+            self.qpc_session = QAICInferenceSession(str(self.qpc_path), device_ids)
+            self.batch_size = self.qpc_session.bindings[0].dims[0]
+
+        # Dynamically switch to the closest compiled seq_len based on the input length
+        inputs = processor(inputs, return_tensors="pt")
+        input_ids_len = inputs["input_values"].shape[-1]
+
+        for allowed_shape in self.qpc_session.allowed_shapes:
+            seq_len_allowed = allowed_shape[1][1][1]
+
+            if seq_len_allowed >= input_ids_len:
+                self.seq_len = seq_len_allowed
+                break
+
+        # To handle a single seq_len, as we can't fetch allowed shapes for a single seq_len
+        self.seq_len = self.qpc_session.bindings[0].dims[1] if not hasattr(self, "seq_len") else self.seq_len
+        input_values = np.array(
+            torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
+        )
+        inputs = dict(input_values=input_values)
+        outputs = self.qpc_session.run(inputs)
+        logits = outputs["logits"]
+        predicted_ids = np.argmax(logits, axis=-1)
+        transcriptions = processor.batch_decode(torch.tensor(predicted_ids))
+        return transcriptions
+
+    def pytorch_feature_generate(self, processor, model, inputs: Union[torch.Tensor, np.ndarray]) -> List[torch.Tensor]:
+        """
+        Generates transcriptions for the given audio input using the PyTorch runtime.
+
+        ``Mandatory`` Args:
+            :model: The transformed PyTorch model used for generating transcriptions.
+            :processor (AutoProcessor): The processor to use for encoding the waveform.
+            :inputs (Union[torch.Tensor, np.ndarray]): Inputs to run the execution.
+        """
+        input_values = processor(
+            inputs[0], return_tensors="pt", max_length=self.seq_len, truncation=True, padding="max_length"
+        ).input_values
+        logits = model(input_values[0]).logits
+        logits = logits.detach().numpy()
+        predicted_ids = np.argmax(logits, axis=-1)
+        transcriptions = processor.batch_decode(predicted_ids)
+        return transcriptions
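`cloud_ai_100_feature_generate` above picks the smallest compiled `seq_len` specialization that fits the processed waveform and zero-pads up to it before running the qpc. Below is a standalone sketch of that bucket-and-pad step; the helper name and the 160,000-sample bucket are illustrative, not part of the commit:

```python
import numpy as np
import torch


def pad_to_allowed_seq_len(input_values: torch.Tensor, allowed_seq_lens: list) -> np.ndarray:
    """Zero-pad the waveform to the smallest allowed seq_len that fits it (illustrative helper)."""
    input_len = input_values.shape[-1]
    target = max(allowed_seq_lens)  # fall back to the largest bucket if nothing smaller fits
    for sl in sorted(allowed_seq_lens):
        if sl >= input_len:
            target = sl
            break
    return np.array(torch.nn.functional.pad(input_values, (0, target - input_len), "constant", 0))


# A 3-second clip at 16 kHz (48,000 samples) lands in the hypothetical 160,000-sample bucket
waveform = torch.zeros((1, 48000), dtype=torch.float32)
padded = pad_to_allowed_seq_len(waveform, [160000, 480000])
print(padded.shape)  # (1, 160000)
```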

QEfficient/utils/constants.py

Lines changed: 3 additions & 0 deletions
@@ -122,6 +122,9 @@ def get_models_dir():
 # Gemma3 Constant
 GEMMA3_MAX_POSITION_EMBEDDINGS = 32768
 
+# Wav2Vec2 Constant
+WAV2VEC2_MAX_SEQ_LEN = 480000  # 30 seconds of audio at a 16 kHz sampling rate (16,000 samples/sec × 30 sec)
+
 
 class Constants:
     # Export Constants.
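The new constant is simply the sample count of the longest supported clip at Wav2Vec2's 16 kHz input rate; a quick sanity check of the arithmetic (illustrative, not part of the commit):

```python
SAMPLE_RATE_HZ = 16_000          # Wav2Vec2 expects 16 kHz input
WAV2VEC2_MAX_SEQ_LEN = 480_000   # value added by this commit

# 480,000 samples / 16,000 samples per second = 30 seconds of audio
assert WAV2VEC2_MAX_SEQ_LEN / SAMPLE_RATE_HZ == 30.0
```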
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Speech Recognition with Wav2Vec2
+This directory contains an example script showing how to use the AutoModelForCTC class. (For now, only Wav2Vec2 models on audio shorter than 30 seconds have been validated.)
+
+## Required packages:
+- `librosa==0.10.2`
+- `soundfile==0.13.1`
+
+You can install them using pip:
+```sh
+pip install librosa==0.10.2 soundfile==0.13.1
+```
+
+To run the example script after installing the packages:
+```sh
+python run_wav2vec2_inference.py
+```
+
+Expected output for the given data sample:
+```sh
+MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL
+```
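To transcribe your own recording rather than the bundled dataset sample, something along these lines should work with the packages listed above; the file name `speech.wav` is a placeholder, and only audio shorter than 30 seconds has been validated:

```python
import librosa
from transformers import AutoProcessor

from QEfficient import QEFFAutoModelForCTC

base_model_name = "facebook/wav2vec2-base-960h"

# librosa (backed by soundfile) loads the clip and resamples it to 16 kHz mono
data, sample_rate = librosa.load("speech.wav", sr=16000, mono=True)

processor = AutoProcessor.from_pretrained(base_model_name)
model = QEFFAutoModelForCTC.from_pretrained(base_model_name)
model.compile(num_cores=16)
print(model.generate(processor, inputs=data))
```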
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+from datasets import load_dataset
+from transformers import AutoProcessor
+
+from QEfficient import QEFFAutoModelForCTC
+
+base_model_name = "facebook/wav2vec2-base-960h"
+
+ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+data = ds[0]["audio"]["array"]
+# Reshape so the array corresponds to data with batch size 1
+data = data.reshape(-1)
+sample_rate = ds[0]["audio"]["sampling_rate"]
+processor = AutoProcessor.from_pretrained(base_model_name)
+
+model = QEFFAutoModelForCTC.from_pretrained(base_model_name)
+model.compile(num_cores=16)
+print(model.generate(processor, inputs=data))
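The example compiles a single 480,000-sample (30 s) specialization. Because `compile` also accepts a list for `seq_len`, a variant like the following could build multiple buckets so shorter clips run against a smaller specialization; the 160,000-sample bucket is illustrative and not something the commit ships:

```python
# Hypothetical variant of the example above: compile 10 s and 30 s buckets at 16 kHz
model = QEFFAutoModelForCTC.from_pretrained(base_model_name)
model.compile(seq_len=[160000, 480000], num_cores=16)

# generate() pads each clip to the smallest compiled bucket that fits it
print(model.generate(processor, inputs=data))
```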
