14 changes: 9 additions & 5 deletions QEfficient/__init__.py
@@ -6,17 +6,21 @@
# -----------------------------------------------------------------------------

import os
import warnings

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger

# ----------------------------------------------------------------------------- #
# For faster downloads via hf_transfer
# This code is placed above the import statements because it must run before
# hf_transfer is imported (which happens via the QEfficient.utils imports below)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# DO NOT ADD ANY CODE ABOVE THIS LINE
# Please contact maintainers if you must edit this file above this line.
# ----------------------------------------------------------------------------- #
# Placeholder for all non-transformer models registered in QEfficient
import warnings # noqa: I001

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger


# Custom warning for a better logging experience
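The reordering in this file matters because huggingface_hub reads `HF_HUB_ENABLE_HF_TRANSFER` once, at import time. A minimal sketch of the pattern; the `constants` check is an assumption about where current huggingface_hub releases cache the flag:

```python
import os

# Must run before anything imports huggingface_hub: the flag is read once
# at import time, so setting it afterwards has no effect on downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import constants  # noqa: E402

# Assumption: current huggingface_hub releases expose the parsed flag here.
assert constants.HF_HUB_ENABLE_HF_TRANSFER
```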
41 changes: 34 additions & 7 deletions QEfficient/base/modeling_qeff.py
@@ -57,6 +57,8 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
self.prefill_enabled = False
self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -179,6 +181,7 @@ def _export(
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
prefill_only: Optional[bool] = False,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -207,7 +210,10 @@

# Return early if ONNX already exists
if onnx_path.is_file():
self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

# Check whether the model is in meta state or its weights are offloaded
@@ -283,10 +289,29 @@

finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

def get_onnx_path(
    self,
    prefill_only: Optional[bool] = False,
    specializations: Optional[List[Dict[str, int]]] = None,
    offload_pt_weights: Optional[bool] = True,
):
    """Return the ONNX path for the requested graph, exporting it on first use.

    ``specializations`` must be provided when ``prefill_only`` is True, since
    the prefill graph is specialized to the prefill sequence length.
    """
    kwargs = {"offload_pt_weights": offload_pt_weights}
    if prefill_only:
        if self.prefill_onnx_path is None:
            kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")})
            self.export(**kwargs)
        return self.prefill_onnx_path
    else:
        if self.onnx_path is None:
            self.export(**kwargs)
        return self.onnx_path

@dump_qconfig
def _compile(
self,
@@ -300,6 +325,8 @@
num_speculative_tokens: Optional[int] = None,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
prefill_only: Optional[bool] = None,
offload_pt_weights: Optional[bool] = True,
**compiler_options,
) -> str:
"""
@@ -325,10 +352,9 @@

For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
if onnx_path is None and self.onnx_path is None:
self.export()

onnx_path = Path(onnx_path or self.onnx_path)
onnx_path = Path(
onnx_path if onnx_path else self.get_onnx_path(prefill_only, specializations, offload_pt_weights)
)
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
@@ -390,6 +416,7 @@ def _compile(
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
"prefill_only": prefill_only,
}
compile_hash = hash_dict_params(compile_hash_params)

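Taken together, `get_onnx_path` and the `_compile` change let callers reach the prefill-only graph without triggering a full export. A hedged usage sketch: the specialization values below are illustrative assumptions, and only the method names come from this diff.

```python
# Hypothetical caller of the new dual-path export; specialization values
# are placeholders, not defaults from this PR.
specializations = [{"batch_size": 1, "seq_len": 128, "ctx_len": 2048}]

# First call: no cached prefill ONNX yet, so export(prefill_only=True,
# prefill_seq_len=128) runs and the result lands in self.prefill_onnx_path.
prefill_onnx = model.get_onnx_path(
    prefill_only=True,
    specializations=specializations,
    offload_pt_weights=False,  # keep weights in memory if a second export follows
)

# Second call: takes the other branch and exports (or reuses) the full graph,
# so the two ONNX artifacts never overwrite each other.
full_onnx = model.get_onnx_path(prefill_only=False, specializations=specializations)
```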
3 changes: 2 additions & 1 deletion QEfficient/peft/auto.py
@@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with the active adapter to ONNX format.

@@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
**kwargs,
)

def compile(
3 changes: 2 additions & 1 deletion QEfficient/peft/lora/auto.py
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
# load_weight to model
self._load_adapter_weights_to_model()

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.

@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
output_names,
dynamic_axes,
export_dir=export_dir,
**kwargs,
)

def generate(
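Both PEFT wrappers now forward extra keyword arguments to `_export`, so flags such as `offload_pt_weights` pass through without adding new wrapper parameters. A small sketch, assuming a PEFT adapter checkpoint (the adapter ID is illustrative):

```python
from QEfficient import QEffAutoPeftModelForCausalLM

# Illustrative adapter checkpoint; any PEFT adapter repo works here.
qeff_model = QEffAutoPeftModelForCausalLM.from_pretrained("predibase/gsm8k")

# offload_pt_weights is not a named parameter of export(); it reaches
# _export() via the new **kwargs passthrough.
onnx_path = qeff_model.export(export_dir="onnx_exports", offload_pt_weights=False)
```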
31 changes: 31 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -594,6 +594,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def write_only(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Write new key/value states into the cache without gathering past entries."""
    if len(self.key_cache) <= layer_idx:
        # First write for this layer: the incoming states become the cache.
        self.key_cache.append(key_states)
        self.value_cache.append(value_states)
        k_out, v_out = key_states, value_states
    else:
        position_ids = cache_kwargs.get("position_ids")
        is_sliding_layer = cache_kwargs.get("is_sliding")
        _, _, ctx_len, _ = self.key_cache[layer_idx].shape
        if is_sliding_layer:
            # Sliding-window layers overwrite the full window.
            kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
        else:
            kv_position_ids = position_ids

        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
        self.value_cache[layer_idx] = CtxScatterFunc.apply(
            self.value_cache[layer_idx], kv_position_ids, value_states
        )
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
    return k_out, v_out

def update(
self,
key_states: torch.Tensor,
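Conceptually, `write_only` is a scatter of the incoming key/value states at the given cache positions. Below is a minimal pure-torch sketch of that semantics; the shapes and the helper name are assumptions for illustration, and the real code routes through `CtxScatterFunc` so the write lowers cleanly to ONNX:

```python
import torch

def scatter_kv(cache: torch.Tensor, position_ids: torch.Tensor, new_states: torch.Tensor) -> torch.Tensor:
    """Write new_states into cache at position_ids along the context axis.

    cache:        [batch, heads, ctx_len, head_dim]
    position_ids: [batch, seq_len]
    new_states:   [batch, heads, seq_len, head_dim]
    """
    batch_idx = torch.arange(cache.shape[0]).view(-1, 1, 1)
    head_idx = torch.arange(cache.shape[1]).view(1, -1, 1)
    cache[batch_idx, head_idx, position_ids.unsqueeze(1)] = new_states
    return cache

cache = torch.zeros(1, 2, 8, 4)   # ctx_len = 8
new_kv = torch.ones(1, 2, 3, 4)   # 3 freshly computed tokens
pos = torch.tensor([[4, 5, 6]])   # write them at context positions 4..6
cache = scatter_kv(cache, pos, new_kv)
assert cache[0, 0, 5].sum() == 4  # position 5 now holds the new states
```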
3 changes: 3 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -188,6 +188,9 @@
# This is for supporting different seq_len values per layer (sliding-window attn, chunked attn, etc.)
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# This is for supporting modeling classes written specifically for prefill-only models
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}

# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
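A short sketch of how such an architecture set is typically consulted when choosing modeling classes; the dispatch helper below is hypothetical, not code from this PR:

```python
from transformers import AutoConfig

SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}

def uses_specialized_prefill_modeling(model_name: str) -> bool:
    # config.model_type (e.g. "gpt_oss") names the architecture family,
    # matching the entries in the set above.
    config = AutoConfig.from_pretrained(model_name)
    return config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
```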