diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 33c6f5588..0efccd41b 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -6,17 +6,21 @@ # ----------------------------------------------------------------------------- import os -import warnings - -import QEfficient.utils.model_registery # noqa: F401 -from QEfficient.utils import custom_format_warning -from QEfficient.utils.logging_utils import logger +# ----------------------------------------------------------------------------- # # For faster downloads via hf_transfer # This code is put above import statements as this needs to be executed before # hf_transfer is imported (will happen on line 15 via leading imports) os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +# DO NOT ADD ANY CODE ABOVE THIS LINE +# Please contact maintainers if you must edit this file above this line. +# ----------------------------------------------------------------------------- # # Placeholder for all non-transformer models registered in QEfficient +import warnings # noqa: I001 + +import QEfficient.utils.model_registery # noqa: F401 +from QEfficient.utils import custom_format_warning +from QEfficient.utils.logging_utils import logger # custom warning for the better logging experience diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6ecbf0fc0..c2c5a7212 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -57,6 +57,8 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) + self.prefill_enabled = False + self.prefill_onnx_path: Optional[str] = None self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None @@ -179,6 +181,7 @@ def _export( onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, + prefill_only: Optional[bool] = False, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -207,7 +210,10 @@ def _export( # Return early if ONNX already exists if onnx_path.is_file(): - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path # check if the model is in meta state or weights are offloaded @@ -283,10 +289,29 @@ def _export( finally: shutil.rmtree(tmp_onnx_dir, ignore_errors=True) - - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path + def get_onnx_path( + self, + prefill_only: Optional[bool] = False, + specializations: Optional[List[Dict[str, int]]] = None, + offload_pt_weights: Optional[bool] = True, + ): + kwargs = {"offload_pt_weights": offload_pt_weights} + if prefill_only: + if self.prefill_onnx_path is None: + kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) + self.export(**kwargs) + return self.prefill_onnx_path + else: + if self.onnx_path is None: + self.export(**kwargs) + return self.onnx_path + @dump_qconfig def _compile( self, @@ -300,6 +325,8 @@ def _compile( num_speculative_tokens: Optional[int] = None, enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, + prefill_only: Optional[str] = None, + offload_pt_weights: Optional[bool] = True, **compiler_options, ) -> str: """ @@ -325,10 +352,9 @@ def _compile( For QNN Compilation path, when enable_qnn is set to 
True, any parameter passed in compiler_options will be ignored. """ - if onnx_path is None and self.onnx_path is None: - self.export() - - onnx_path = Path(onnx_path or self.onnx_path) + onnx_path = Path( + onnx_path if onnx_path else self.get_onnx_path(prefill_only, specializations, offload_pt_weights) + ) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): @@ -390,6 +416,7 @@ def _compile( "mdp_ts_num_devices": mdp_ts_num_devices, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, + "prefill_only": prefill_only, } compile_hash = hash_dict_params(compile_hash_params) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index 592c0c1d3..e5b8bd185 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with the active adapter to ONNX format. @@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str: export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, + **kwargs, ) def compile( diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 8196cd769..8ff8335f5 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. 
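The `**kwargs` added to these `export()` overrides let prefill-related arguments pass through to the base export path; on `QEFFBaseModel`, the new `get_onnx_path()` helper is what `_compile()` now calls to lazily resolve either a prefill-only or a regular ONNX artifact. A minimal usage sketch of that dispatch, assuming a GPT-OSS checkpoint (the model id, shape values, and the direct helper call are illustrative assumptions, not part of the patch):

    from QEfficient import QEFFAutoModelForCausalLM

    qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
    specializations = [{"batch_size": 1, "seq_len": 256, "ctx_len": 256}]

    # prefill_only=True: exports on first use and caches the result in prefill_onnx_path;
    # specializations[0]["seq_len"] is forwarded to export() as prefill_seq_len.
    prefill_onnx = qeff_model.get_onnx_path(
        prefill_only=True,
        specializations=specializations,
        offload_pt_weights=False,  # keep torch weights resident so a second export can follow
    )

    # prefill_only=False (default): falls back to the regular KV-cache export cached in onnx_path.
    decode_onnx = qeff_model.get_onnx_path(prefill_only=False, specializations=specializations)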
@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names, dynamic_axes, export_dir=export_dir, + **kwargs, ) def generate( diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 853567be9..18a15e480 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -594,6 +594,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache + def write_only( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + _, _, ctx_len, _ = self.key_cache[layer_idx].shape + if is_sliding_layer: + kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1) + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + return k_out, v_out + def update( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 5337b44f5..47059d8dc 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -188,6 +188,9 @@ # This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} +# This is for supporting different modelling classes specially written for prefill-only model +SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"} + # Define a transformers layers to QEff layers dictionary # While onboarding new models make sure to add the new layer maps to this dictionary. 
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = { diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 62bc849b7..946a2851c 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +import math +import os from typing import Callable, Optional, Union import torch @@ -32,6 +34,7 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils import constants from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger class QEffGptOssExperts(GptOssExperts): @@ -42,8 +45,10 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) -class QEffGptOssMLP(GptOssMLP): - def alt_forward(self, hidden: torch.Tensor): +class QEffPrefillOnlyGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + if os.environ.get("NUM_FFN_BLOCKS", None) is not None: + return self.blocked_ffn_forward(hidden) B, S, H = hidden.shape T = B * S hidden = hidden.view(T, H) @@ -95,6 +100,169 @@ def alt_forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits + def blocked_ffn_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + # Gate and Up projections + gate = (tgb @ W_g) + b_g # [T, I] + up = (tgb @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # 
Down projection + down_out_block = (intermediate @ W_d) + b_d # [T, H] + + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + + wg_col_shape = W_g.shape[1] + wg_num_blocks = math.ceil(wg_col_shape / 128) + last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128 + + intermediates = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:] + cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:] + else: + cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128] + cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128] + + cur_gate = cur_gate.clamp(min=None, max=self.experts.limit) + cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit) + cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha) + cur_intermediate = (cur_up + 1) * cur_glu + intermediates.append(cur_intermediate) + + intermediate = torch.cat(intermediates, dim=-1) + + downs = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:]) + else: + downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128]) + + down_out_block = torch.cat(downs, dim=1) + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + 
# original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + +class QEffGptOssMLP(GptOssMLP): # ------------------- Gather based, weights as activation approach --------------- def forward_weights_as_activation(self, hidden_states): bs, seq_len, _ = hidden_states.shape @@ -142,7 +310,6 @@ def forward_weights_as_activation(self, hidden_states): # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- def forward(self, hidden_states): - # print("Seperate Split, Up, Gate Projections") bs, seq_len, _ = hidden_states.shape hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) @@ -404,6 +571,134 @@ def eager_attention_forward( return attn_output, attn_weights +def eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + q_block = query[:, :, qi : qi + real_q_len, :] + scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, value_states) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + +class QEffPrefillOnlyGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = 
self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + if self.sliding_window is not None: + sliding_window_len = past_key_value.sliding_window_len + short_read_idx = torch.arange(sliding_window_len) + read_idx = short_read_idx + torch.where( + position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 + ) + # This is a trick to export with NUM_BLOCKS position_ids.max(), 0, read_idx) + k_cache = key_states[:, :, read_idx, :] + v_cache = value_states[:, :, read_idx, :] + else: + k_cache, v_cache = key_states, value_states + _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward_blocked + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -428,8 +723,9 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -505,7 +801,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores - # alth, _ = self.mlp.alt_forward(hidden_states) hidden_states = hidden_states.reshape(residual.shape) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -519,6 +814,98 @@ def forward( return outputs +class QEffPrefillOnlyGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: 
Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.max_cache_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + class QEffGptOssModel(GptOssModel): def forward( self, @@ -571,7 +958,6 @@ def forward( ) hidden_states = inputs_embeds - # position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -714,9 +1100,15 @@ def get_specializations( batch_size: int, prefill_seq_len: int, ctx_len: int, + **kwargs, ): batch_size = batch_size if batch_size else 1 prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN + if kwargs.get("prefill_only") and ctx_len != prefill_seq_len: + ctx_len = prefill_seq_len + logger.warning( + f"overriding ctx_len={prefill_seq_len}, currently we don't support ctx_len different than 
prefill_seq_len for prefill_only model" + ) ctx_len = ctx_len if ctx_len else constants.CTX_LEN specializations = [ diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 60f60c768..797277fe5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import os import warnings from pathlib import Path from time import perf_counter @@ -37,12 +38,17 @@ get_compilation_dims, ) from QEfficient.generation.vlm_generation import VisionLanguageGeneration -from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + SPECIALIZED_PREFILL_ONLY_MODEL_ARCH, +) from QEfficient.transformers.models.pytorch_transforms import ( CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyTransform, + RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, VlmKVOffloadTransform, @@ -314,7 +320,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -594,7 +600,7 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the vision encoder component to ONNX format. @@ -736,7 +742,7 @@ def __init__(self, model, **kwargs): self.model = model.get_qeff_language_decoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the language decoder component to ONNX format. 
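The prefill-only GPT-OSS kernels above slice the token axis into NUM_BLOCKS / NUM_FFN_BLOCKS chunks, with the final chunk absorbing any remainder. A small standalone sketch of that block-position arithmetic (pure Python; the function name is illustrative):

    def block_spans(total_len: int, target_blocks: int):
        # Mirrors block_positions in blocked_ffn_forward / eager_attention_forward_blocked:
        # equal chunks of floor(total_len / target_blocks); the last chunk takes the remainder.
        starts = [j * (total_len // target_blocks) for j in range(target_blocks)]
        spans = []
        for idx, start in enumerate(starts):
            end = total_len if idx == target_blocks - 1 else starts[idx + 1]
            spans.append((start, end - start))  # (offset into the flattened tokens, block length)
        return spans

    # block_spans(256, 4) -> [(0, 64), (64, 64), (128, 64), (192, 64)]
    # block_spans(258, 4) -> [(0, 64), (64, 64), (128, 64), (192, 66)]  # remainder lands in the last block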
@@ -2113,11 +2119,20 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + def prefill(self, enable: Optional[bool] = True): + if enable: + self.model, tf = PrefillOnlyTransform.apply(self.model) + self.prefill_enabled = True + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + self.prefill_enabled = False + def __init__( self, model: nn.Module, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, **kwargs, ): """ @@ -2163,6 +2178,7 @@ def __init__( ) # Set use_cache=True to get KV values as output during ONNX export model.config.use_cache = True + setattr(model.config, "max_seq_len_cached", max_seq_len_cached) super().__init__(model, qaic_config=qaic_config, **kwargs) self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching @@ -2171,6 +2187,7 @@ def __init__( self.is_tlm = transformed self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.hash_params["max_seq_len_cached"] = max_seq_len_cached # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -2207,6 +2224,7 @@ def from_pretrained( pretrained_model_name_or_path, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, *args, **kwargs, ): @@ -2280,6 +2298,7 @@ def from_pretrained( continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + max_seq_len_cached=max_seq_len_cached, **kwargs, ) @@ -2295,7 +2314,51 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Optional[int] = None) -> int: + num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) + if num_q_blocks is None: + block_size = 128 + if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " + f"Or set `NUM_BLOCKS` ENV variable" + f"Received: prefill_seq_len={prefill_seq_len}" + ) + + num_q_blocks = prefill_seq_len // block_size + logger.warning( + f"Setting NUM_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_BLOCKS` to override" + ) + os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks) + num_q_blocks = int(num_q_blocks) + + num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None) + num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks + min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks + if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0: + raise ValueError( + f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but," + "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values." 
+ ) + + self.prefill(True) + self.hash_params["prefill_only"] = True + self.hash_params["num_blocks"] = num_q_blocks + self.hash_params["num_ffn_blocks"] = num_ffn_blocks + return ( + min_seq_len + if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + ) + + def export( + self, + export_dir: Optional[str] = None, + prefill_only: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + offload_pt_weights: Optional[bool] = True, + **kwargs, + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2317,6 +2380,18 @@ def export(self, export_dir: Optional[str] = None) -> str: bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + if prefill_only: + assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" + seq_len = ( + self.get_seq_len_and_handle_specialized_prefill_model(prefill_seq_len) + if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH + else seq_len + ) + else: + self.prefill(False) + self.hash_params.pop("prefill_only", None) + self.hash_params.pop("num_blocks", None) + kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len ) @@ -2394,12 +2469,13 @@ def export(self, export_dir: Optional[str] = None) -> str: output_names=output_names, dynamic_axes=dynamic_axes, ) - return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, + offload_pt_weights=offload_pt_weights, + prefill_only=prefill_only, ) def get_sampling_inputs_and_outputs( @@ -2488,6 +2564,7 @@ def build_prefill_specialization( batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the prefill phase. @@ -2515,6 +2592,7 @@ def build_prefill_specialization( batch_size=1 if self.continuous_batching else batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + **kwargs, )[0] else: spec = { @@ -2603,6 +2681,7 @@ def compile( mxint8_kv_cache: bool = False, num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, + offload_pt_weights: Optional[bool] = True, **compiler_options, ) -> str: """ @@ -2705,6 +2784,9 @@ def compile( ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") + if kv_cache_batch_size and prefill_only is not None and prefill_only: + logger.warning("kv_cache_batch_size will be ignored as prefill_only is set to True") + # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size @@ -2740,7 +2822,6 @@ def compile( for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -2754,6 +2835,8 @@ def compile( num_speculative_tokens=num_speculative_tokens, aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, + prefill_only=prefill_only, + offload_pt_weights=offload_pt_weights, **compiler_options, ) @@ -2943,7 +3026,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. 
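For GPT-OSS, `export(prefill_only=True)` routes through `get_seq_len_and_handle_specialized_prefill_model()` above, which validates `prefill_seq_len` against the 128-token Q block and applies the PrefillOnly transform. A minimal sketch of how the new flags are expected to be driven (model id and sizes are illustrative; this mirrors the logic above rather than replacing it):

    import os
    from QEfficient import QEFFAutoModelForCausalLM

    os.environ["NUM_FFN_BLOCKS"] = "4"  # optional: enables blocked expert FFN during prefill

    model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")

    # prefill_only export: prefill_seq_len must be set and be a multiple of the 128-token
    # Q block (unless NUM_Q_BLOCKS is preset); NUM_Q_BLOCKS is then derived as prefill_seq_len // 128.
    prefill_onnx = model.export(prefill_only=True, prefill_seq_len=256, offload_pt_weights=False)

    # Exporting again with prefill_only=False reverts the PrefillOnly module mapping and
    # produces the regular KV-cache graph; the two exports hash to different artifacts.
    decode_onnx = model.export(prefill_only=False)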
@@ -3307,7 +3390,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 773ce178c..b3e60a8b9 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -259,6 +259,9 @@ QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, + QEffPrefillOnlyGptOssAttention, + QEffPrefillOnlyGptOssMLP, + QEffPrefillOnlyGptOssModel, ) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, @@ -630,6 +633,18 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +class PrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, + QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, + } + + +class RevertPrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = {v: k for k, v in PrefillOnlyTransform._module_mapping.items()} + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index dfadc00ef..dc2308e99 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -5,6 +5,6 @@ # # ----------------------------------------------------------------------------- -from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers -__all__ = ["replace_transformers_quantizers"] +__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"] diff --git a/examples/gpt_oss_disagg_mode.py b/examples/gpt_oss_disagg_mode.py new file mode 100644 index 000000000..ee03f573a --- /dev/null +++ b/examples/gpt_oss_disagg_mode.py @@ -0,0 +1,136 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. 
Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +all_outputs = [] +# Run prefill +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 256 +CTX_LEN = 256 +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# Initialize variables specific to request +# Calculate the max generation length. +max_gen_len = CTX_LEN - position_ids.max() +generation_len = max_gen_len + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +config = qeff_model.model.config +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(config.num_hidden_layers): + cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, +) +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, +) + +prefill_session = QAICInferenceSession(prefill_qpc_path) + +logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) +prefill_session.set_buffers({"logits": logits_out_placeholder}) +inputs.pop("past_key_values") +inputs = {k: v.detach().numpy() for k, v in inputs.items()} +st = time.time() +qpc_out = prefill_session.run(inputs) +print(f"time for prefill_run={time.time() - st} sec\n") + +decode_session = QAICInferenceSession(decode_qpc_path) +decode_session.set_buffers({"logits": logits_out_placeholder}) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +print("pos_id for decodee", decode_inputs["position_ids"]) + +all_outputs.append(decode_inputs["input_ids"][0][0]) +for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = 
np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] +) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +st = time.time() +for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + +print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") +print(all_outputs) +print(tokenizer.decode(all_outputs)) diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index d9d391d47..bf6f82ce3 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -42,7 +42,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic && - pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml && + pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml && junitparser merge tests/tests_log1.xml tests/tests_log.xml && deactivate" ''' diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 00a4216b7..46b33c60b 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( # export start = perf_counter() - qeff_model.export(export_dir=tmp_path) + onnx_path = qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_0 = end - start model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash) @@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( assert export_time_1 < export_time_0 # test compile - qeff_model.compile(prefill_seq_len=32, ctx_len=64) + qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64) assert Path(qeff_model.qpc_path).is_dir() assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py index cc94467db..c3bb2f140 100644 --- a/tests/peft/test_peft_model.py +++ b/tests/peft/test_peft_model.py @@ -178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path): _, lora_model = create_peft_model(base_config, adapter_config) qeff_model = QEffAutoPeftModelForCausalLM(lora_model) - qeff_model.export(tmp_path) + onnx_path = qeff_model.export(tmp_path) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_0 = end - start @@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con ) start = perf_counter() - 
qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_1 = end - start assert compile_time_1 < 0.01 * compile_time_0 diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py new file mode 100644 index 000000000..67ee48944 --- /dev/null +++ b/tests/transformers/models/test_disagg_mode.py @@ -0,0 +1,104 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +prompt2 = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. 
+""" +prompt1 = "Once upon a time" + +prompts = [prompt1, prompt2] + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 256 + CTX_LEN = 256 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + qeff_out = qeff_model.model(**inputs) + + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + qpc_out = prefill_session.run(inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index 0810ac6ba..925af8b3a 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -17,7 +17,7 @@ from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.hash_utils import hash_dict_params -configs = [ +test_configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, 
intermediate_size, vocab_size, additional_params ("gpt2", 256, 2, 4, 128, 512, 127, {}), ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), @@ -36,30 +36,43 @@ ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] -configs = [ - AutoConfig.for_model( - model_name, - max_position_embeddings=max_position_embeddings, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - vocab_size=vocab_size, - **additional_params, - ) - for ( - model_name, - max_position_embeddings, - num_hidden_layers, - num_attention_heads, - hidden_size, - intermediate_size, - vocab_size, - additional_params, - ) in configs +test_prefill_only_specialized_models_configs = [ + ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}), ] + + +def get_auto_config_from_test_config(configs): + auto_configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs + ] + return auto_configs + + +configs = get_auto_config_from_test_config(test_configs) config_ids = [x.model_type for x in configs] +prefill_only_configs = get_auto_config_from_test_config(test_prefill_only_specialized_models_configs) +prefill_only_config_ids = [x.model_type for x in prefill_only_configs] + model_kwargs = {"attn_implementation": "eager"} @@ -154,10 +167,10 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): hash_params["peft_config"] = None hash_params["applied_transform_names"] = qeff_model._transform_names() hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["max_seq_len_cached"] = None hash_params["qaic_config"] = None # Create parameters separately for hash creation - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS @@ -209,6 +222,24 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): assert manual_hash == qeff_model.export_hash +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", prefill_only_configs, ids=prefill_only_config_ids) +def test_prefill_only_specialized_models(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + if cb: + with pytest.raises(AssertionError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + else: + with pytest.raises(ValueError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False) + first_export_hash = qeff_model.export_hash + qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False) + second_export_hash = qeff_model.export_hash + assert first_export_hash != second_export_hash + + @pytest.fixture def tmp_cache(tmp_path, monkeypatch): monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
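The prefill-to-decode hand-off in the example above re-orders each sliding-window layer's retained KV before it is passed to the decode session. A compact NumPy sketch of that rotation (function name and shapes are illustrative):

    import numpy as np

    def rotate_sliding_window_cache(cache: np.ndarray, next_position_id: int, sliding_window: int) -> np.ndarray:
        # cache: [batch, heads, sliding_window, head_dim] retained state from the prefill QPC.
        # Move the tail segment starting at sliding_window - next_position_id % sliding_window
        # to the front, as the example does before feeding the decode session.
        split = sliding_window - next_position_id % sliding_window
        return np.concatenate((cache[:, :, split:, :], cache[:, :, :split, :]), axis=-2)

    # e.g. sliding_window=128 and first decode position 300: split = 128 - (300 % 128) = 84,
    # so cache[:, :, 84:, :] is moved ahead of cache[:, :, :84, :].
    # The example applies this only to sliding layers once the decode position has passed the window.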