diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h index 5ce2f4e1daf..795de9a599a 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h @@ -59,6 +59,30 @@ __forceinline__ __device__ float tanh_opt(float x) #endif } +template +struct Relu2 +{ + static bool const kIsHeavy = false; + + CUTLASS_HOST_DEVICE + T operator()(T threshold, T value) const + { + ReLu relu_op; + multiplies mul; + T val = relu_op(threshold, value); + return mul(val, val); + } + + CUTLASS_HOST_DEVICE + T operator()(T value) const + { + ReLu relu_op; + multiplies mul; + T val = relu_op(value); + return mul(val, val); + } +}; + } // namespace thread } // namespace epilogue } // namespace cutlass diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h index 646be2575ca..55226c68960 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h @@ -28,6 +28,7 @@ enum class ActivationType Swiglu, Geglu, SwigluBias, + Relu2, Identity, InvalidType }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index 2c0d1a94a53..477634cd3c0 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -954,6 +954,7 @@ void MoeGemmRunner::moeGemmBiasAct( case ActivationType::Identity: runGemm(inputs, hopper_inputs); break; case ActivationType::Swiglu: runGemm(inputs, hopper_inputs); break; case ActivationType::Geglu: runGemm(inputs, hopper_inputs); break; + case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break; case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; default: TLLM_THROW("Invalid activation type."); break; } diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 0fb56f3893d..383ad87e988 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -2307,6 +2307,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 decltype(block_scaling_type)::value>, // Geglu &doActivationKernel, // SwigluBias + &doActivationKernel, + decltype(block_scaling_type)::value>, // Relu2 &doActivationKernel, decltype(block_scaling_type)::value> // Identity diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index fbed602d464..2cc1fb27977 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder torch::optional const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, torch::optional> const& profile_ids, - torch::optional const& unpadded_hidden_size, torch::optional const& num_valid_tokens, 
- torch::optional const& out_tensor) + torch::optional const& activation_type, torch::optional const& unpadded_hidden_size, + torch::optional const& num_valid_tokens, torch::optional const& out_tensor) { std::lock_guard lock(mMutex); // Free the profile workspace to save memory @@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0], "fc1_expert_weights and fc2_expert_weights must have the same number of experts."); + ActivationType base_activation_type = activation_type.has_value() + ? static_cast(activation_type.value()) + : ActivationType::Swiglu; if (mUseINT8WoqPerChannel) { // Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights: @@ -337,8 +340,16 @@ class FusedMoeRunner : public torch::CustomClassHolder } else { - TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, - "fc1_expert_weights inter size must be fc2_expert_weights inter size."); + if (isGatedActivation(base_activation_type)) + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, + "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + } + else + { + TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier, + "fc1_expert_weights inter size must be equal to fc2_expert_weights inter size."); + } } int experts_per_token = token_selected_experts.sizes()[1]; @@ -375,7 +386,7 @@ class FusedMoeRunner : public torch::CustomClassHolder int const num_experts_on_rank = fc2_expert_weights.sizes()[0]; auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank); - ActivationType base_activation_type = ActivationType::Swiglu; + if (swiglu_alpha.has_value()) { CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float); @@ -474,8 +485,8 @@ class FusedMoeRunner : public torch::CustomClassHolder torch::optional const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, torch::optional> const& profile_ids, - torch::optional const& unpadded_hidden_size, torch::optional const& num_valid_tokens, - torch::optional const& out_tensor) + torch::optional const& activation_type, torch::optional const& unpadded_hidden_size, + torch::optional const& num_valid_tokens, torch::optional const& out_tensor) { std::lock_guard lock(mMutex); @@ -541,7 +552,9 @@ class FusedMoeRunner : public torch::CustomClassHolder auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank); - ActivationType base_activation_type = ActivationType::Swiglu; + ActivationType base_activation_type = activation_type.has_value() + ? 
static_cast(activation_type.value()) + : ActivationType::Swiglu; if (swiglu_alpha.has_value()) { CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float); @@ -652,7 +665,8 @@ class FusedMoeRunner : public torch::CustomClassHolder torch::optional const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx, - int64_t const profile_id, bool const do_preparation, int64_t const unpadded_hidden_size) + int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int, + int64_t const unpadded_hidden_size) { std::lock_guard lock(mMutex); @@ -661,6 +675,7 @@ class FusedMoeRunner : public torch::CustomClassHolder { return; } + ActivationType activation_type = static_cast(activation_type_int); int64_t const num_rows = input.sizes()[0]; int64_t hidden_size = fc2_expert_weights.sizes()[1]; @@ -715,14 +730,14 @@ class FusedMoeRunner : public torch::CustomClassHolder tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype), tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast(top_k), hidden_size, unpadded_hidden_size > 0 ? unpadded_hidden_size : hidden_size, inter_size, group_size, - ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, + activation_type, USE_BIAS, USE_LORA, min_latency_mode, /*need_weights*/ false, parallelism_config, enable_alltoall); #else mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile, tensorrt_llm::runtime::TorchUtils::dataType(activation_dtype), tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype), tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast(top_k), - hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode, + hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, min_latency_mode, /*need_weights*/ false, parallelism_config); #endif diff --git a/examples/auto_deploy/.gitignore b/examples/auto_deploy/.gitignore index 9836a37fc88..a0ef9cd4947 100644 --- a/examples/auto_deploy/.gitignore +++ b/examples/auto_deploy/.gitignore @@ -4,3 +4,6 @@ benchmark_results.json *.png # ignore config files that users might put here for debugging *.yaml +!nano_v3.yaml +!nano_v3_accuracy.yaml +!nano_v3_bench.yaml diff --git a/examples/auto_deploy/nano_v3.yaml b/examples/auto_deploy/nano_v3.yaml new file mode 100644 index 00000000000..411037cc175 --- /dev/null +++ b/examples/auto_deploy/nano_v3.yaml @@ -0,0 +1,23 @@ +runtime: trtllm +compile_backend: torch-cudagraph +max_batch_size: 384 +max_seq_len: 65536 # tunable +enable_chunked_prefill: true +attn_backend: flashinfer +model_factory: AutoModelForCausalLM +skip_loading_weights: false +free_mem_ratio: 0.9 +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384] +kv_cache_config: + # disable kv_cache reuse since not supported for hybrid/ssm models + enable_block_reuse: false +transforms: + detect_sharding: + sharding_source: ['factory', 'heuristic'] + sharding_dims: ['ep', 'bmm'] + # tunable mamba cache dtype + # --> use float32 for accuracy and default (null) for speed + insert_cached_ssm_attention: + cache_config: + # mamba_dtype: float32 + mamba_dtype: null diff --git a/examples/auto_deploy/nano_v3_accuracy.yaml b/examples/auto_deploy/nano_v3_accuracy.yaml new file mode 100644 index 00000000000..4b848977120 --- /dev/null +++ 
b/examples/auto_deploy/nano_v3_accuracy.yaml @@ -0,0 +1,23 @@ +runtime: trtllm +compile_backend: torch-cudagraph +max_batch_size: 128 +max_seq_len: 204800 +enable_chunked_prefill: true +attn_backend: flashinfer +model_factory: AutoModelForCausalLM +skip_loading_weights: false +free_mem_ratio: 0.9 +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128] +kv_cache_config: + # disable kv_cache reuse since not supported for hybrid/ssm models + enable_block_reuse: false +transforms: + detect_sharding: + sharding_source: ['factory', 'heuristic'] + sharding_dims: ['ep', 'bmm'] + # tunable mamba cache dtype + # --> use float32 for accuracy and default (null) for speed + insert_cached_ssm_attention: + cache_config: + mamba_dtype: float32 + # mamba_dtype: null diff --git a/examples/auto_deploy/nano_v3_bench.yaml b/examples/auto_deploy/nano_v3_bench.yaml new file mode 100644 index 00000000000..fc9b04e2640 --- /dev/null +++ b/examples/auto_deploy/nano_v3_bench.yaml @@ -0,0 +1,23 @@ +runtime: trtllm +compile_backend: torch-cudagraph +max_batch_size: 384 # tunable +max_seq_len: 65536 # tunable +enable_chunked_prefill: true +attn_backend: flashinfer +model_factory: AutoModelForCausalLM +skip_loading_weights: false +free_mem_ratio: 0.9 +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384] +kv_cache_config: + # disable kv_cache reuse since not supported for hybrid/ssm models + enable_block_reuse: false +transforms: + detect_sharding: + sharding_source: ['factory', 'heuristic'] + sharding_dims: ['ep', 'bmm'] + # tunable mamba cache dtype + # --> use float32 for accuracy and default (null) for speed + insert_cached_ssm_attention: + cache_config: + # mamba_dtype: float32 + mamba_dtype: null diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index 1fb094e7e2e..4a98593c68b 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -175,7 +175,7 @@ def forward(self, *args, **kwargs) -> Any: # retrieve output from buffer, cut to batch size, and unflatten bs = args_batched[0].shape[0] - out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat] + out_flat = [o_b[:bs] for o_b in self._out_buffer_flat] return self._out_spec.unflatten(out_flat) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index b9d001c3d72..fabde3af980 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -165,6 +165,8 @@ transforms: ############################################################################################ # COMPILE MODEL ############################################################################################ + fuse_causal_conv_activation: + stage: compile compile_model: stage: compile run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 02f7001cff0..1de7050f422 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -10,10 +10,10 @@ """ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union import torch +from pydantic import 
BaseModel, ConfigDict, Field, field_validator from torch._ops import OpOverloadPacket from torch.fx import Node from torch.types import Number @@ -24,11 +24,39 @@ Constant = Union[int, float, str, None] -@dataclass -class CacheConfig: - """A dataclass to hold information how to configure the cache.""" +class CacheConfig(BaseModel): + """Cache configuration for attention-related dtypes.""" - dtype: Optional[torch.dtype] = None + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.") + mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.") + + @field_validator("dtype", "mamba_dtype", mode="before") + @classmethod + def _coerce_dtype(cls, value): + if value is None or isinstance(value, torch.dtype): + return value + if isinstance(value, str): + dtype = getattr(torch, value, None) + assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}" + return dtype + return value + + def __or__(self, other: "CacheConfig") -> "CacheConfig": + """Combine two CacheConfig objects field-wise using Python's `or` semantics. + + For each field, selects the first non-None value between `self` and `other`. + """ + if not isinstance(other, CacheConfig): + raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}") + merged_kwargs = {} + for field_name in type(self).model_fields.keys(): + merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name) + return CacheConfig(**merged_kwargs) class SequenceInfo: @@ -88,6 +116,7 @@ def __init__( page_size: int = 0, max_num_tokens: Optional[int] = None, vocab_size_padded: Optional[int] = None, + chunk_size: Optional[int] = None, ): """Initialize the SequenceInfo object. @@ -114,7 +143,7 @@ def __init__( self.max_batch_size = max_batch_size self.page_size = page_size if page_size > 0 else max_seq_len self.vocab_size_padded = vocab_size_padded - + self.chunk_size = chunk_size # NOTE (lucaslie): WAR to address issue when using flashinfer attention with # (max_batch_size, max_seq_len) input in trtllm runtime. # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 @@ -165,7 +194,7 @@ def __init__( "input_pos": torch.empty(self.max_batch_size, dtype=torch.int), "cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int), "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int), - "slot_idx": torch.empty(self.max_batch_size, dtype=torch.int), + "slot_idx": torch.empty(self.max_batch_size, dtype=torch.long), # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER "_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int), } @@ -175,7 +204,7 @@ def __init__( # NOTE: order of keys is relevant here! 
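+        # NOTE: the cached arg names and constants below are forwarded (in order) to every backend
+        # prepare_metadata op; adding "chunk_size" here is what requires the extra chunk_size
+        # parameter added to each backend's prepare_metadata signature in this change.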
self._uncached_arg_names = ("input_ids", "position_ids") self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx") - self._cached_constants = ("page_size",) + self._cached_constants = ("page_size", "chunk_size") ############################################################################################ # EXTRA TENSOR FIELDS ###################################################################### diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 63a8c7b1547..f621539c06b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -162,6 +162,7 @@ def prepare_flashinfer_metadata( pages_per_seq: torch.Tensor, slot_idx: torch.Tensor, page_size: int, + chunk_size: int, ) -> List[torch.Tensor]: """Prepare metadata for flashinfer attention. @@ -213,7 +214,7 @@ def prepare_flashinfer_metadata( # As SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_flashinfer_metadata.register_fake def prepare_flashinfer_metadata_fake( - position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size ): seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py index dbf307f5d35..c819b46b07b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py @@ -11,6 +11,8 @@ import triton import triton.language as tl +from tensorrt_llm._utils import nvtx_range + @triton.jit def _write_zeros_to_output( @@ -304,6 +306,7 @@ def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict: } +@nvtx_range("triton_moe_pack_routed_tokens") def _pack_routed_tokens( topk_ids: torch.Tensor, M: int, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py index a14d0f436e5..af55193af1c 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py @@ -1,13 +1,17 @@ import torch +from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType + @torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=()) -def trtllm_fused_moe( +def trtllm_moe_fused( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", ) -> torch.Tensor: x_shape = x.shape x = x.view(-1, x_shape[-1]) @@ -16,30 +20,219 @@ def trtllm_fused_moe( selected_experts = selected_experts.to(torch.int32) quant_scales = [] + # Determine activation type + mlp_style = mlp_style.lower() + act_fn = act_fn.lower() + + activation_type = ActivationType.Swiglu + if mlp_style == "gated_mlp": + # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) + if act_fn == "silu": + # activation_type = ActivationType.Silu + activation_type = ActivationType.Swiglu # need to fix this in trtllm + else: + raise ValueError(f"Unsupported activation 
'{act_fn}' for gated_mlp. Use 'silu'.") + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + if act_fn == "relu2": + activation_type = ActivationType.Relu2 + else: + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + return torch.ops.trtllm.fused_moe( x, selected_experts, routing_weights, - w3_w1_stacked_weight, - None, # w3_w1_stacked_bias - w2_stacked_weight, - None, # w2_stacked_bias - x.dtype, - quant_scales, - tp_size=1, - tp_rank=0, - ep_size=1, - ep_rank=0, - enable_alltoall=False, + fc1_expert_weights=w3_w1_stacked_weight, + fc1_expert_biases=None, + fc2_expert_weights=w2_stacked_weight, + fc2_expert_biases=None, + output_dtype=x.dtype, + quant_scales=quant_scales, + activation_type=activation_type, )[0].view(x_shape) -@trtllm_fused_moe.register_fake -def trtllm_fused_moe( +@trtllm_moe_fused.register_fake +def trtllm_moe_fused_fake( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, w3_w1_stacked_weight: torch.Tensor, w2_stacked_weight: torch.Tensor, + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + return torch.empty_like(x) + + +# Todo: refactor this repeating code block +def _quantize_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Quantize tensor to FP8 with clamping (matches torch_quant_fp8_linear).""" + FP8_MIN = torch.finfo(torch.float8_e4m3fn).min + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + return (x / scale).clamp(FP8_MIN, FP8_MAX).to(torch.float8_e4m3fn) + + +def _validate_mlp_style_and_act_fn(mlp_style: str, act_fn: str) -> None: + supported_combinations = { + "gated_mlp": ["silu"], + "mlp": ["relu2"], + } + supported_act_fns = [ + act_fn for act_fn_list in supported_combinations.values() for act_fn in act_fn_list + ] + assert mlp_style in supported_combinations.keys(), ( + f"Unknown mlp_style '{mlp_style}'. Use {supported_combinations.keys()}." + ) + assert act_fn in supported_act_fns, f"Unknown act_fn '{act_fn}'. Use {supported_act_fns}." + assert act_fn in supported_combinations[mlp_style], ( + f"Unsupported combination: mlp_style='{mlp_style}', act_fn='{act_fn}'. " + f"Supported combinations: {supported_combinations}" + ) + + +@torch.library.custom_op("auto_deploy::trtllm_quant_fp8_moe_fused", mutates_args=()) +def trtllm_quant_fp8_moe_fused( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, # [E, I, H] stacked FP8 weights + w2_weight: torch.Tensor, # [E, H, I] stacked FP8 weights + w3_weight: torch.Tensor, # [E, I, H] for gated_mlp, unused for mlp + w1_input_scale: torch.Tensor, # [E] stacked input scales + w2_input_scale: torch.Tensor, # [E] stacked input scales + w3_input_scale: torch.Tensor, # [E] or unused + w1_weight_scale: torch.Tensor, # [E] stacked weight scales + w2_weight_scale: torch.Tensor, # [E] stacked weight scales + w3_weight_scale: torch.Tensor, # [E] or unused + mlp_style: str = "gated_mlp", + act_fn: str = "silu", +) -> torch.Tensor: + """ + TensorRT-LLM Cutlass FP8 W8A8 MoE for gated and non-gated MLP. 
+ Parameters: + x: BF16/FP16 input tensor of shape (B, H) or (B, S, H) + selected_experts: Expert indices (B*S, TOP_K) + routing_weights: Routing weights (B*S, TOP_K) + w1_weight: FP8 w1 weights [E, I, H] + w2_weight: FP8 w2 weights [E, H, I] + w3_weight: FP8 w3 weights [E, I, H] (for gated_mlp) + w1_input_scale: Input scales for w1 [E] + w2_input_scale: Input scales for w2 [E] + w3_input_scale: Input scales for w3 [E] + w1_weight_scale: Weight scales for w1 [E] + w2_weight_scale: Weight scales for w2 [E] + w3_weight_scale: Weight scales for w3 [E] + mlp_style: "gated_mlp" or "mlp" + act_fn: "silu" for gated_mlp, "relu2" for mlp + + Non-Gated MLP: + activation_fn(expert_inputs @ w1_expert.t())@ w2_expert.t() + + Gated MLP: + activation_fn(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) @ w2_expert.t() + """ + + _validate_mlp_style_and_act_fn(mlp_style, act_fn) + + # Store original shape and flatten to 2D + x_shape = x.shape + x2d = x.view(-1, x_shape[-1]) + # Quantize input + x_q_fp8 = _quantize_fp8(x2d, w1_input_scale[0]) + + # Scales are stored in float32 + w1_weight_scale = w1_weight_scale.to(torch.float32) + w2_weight_scale = w2_weight_scale.to(torch.float32) + w1_input_scale = w1_input_scale.to(torch.float32)[0] + w2_input_scale = w2_input_scale.to(torch.float32)[0] + + # Prepare quant_scales for TensorRT-LLM FP8 format: + # [gemm1_dequant_scale, gemm2_act_quant_scale, gemm2_dequant_scale, gemm1_input_dequant_scale] + # For gated MLP: + # - gemm1_dequant_scale: w1_weight_scale * w1_input_scale (combined for w1 and w3) + # - gemm2_act_quant_scale: 1 / w2_input_scale + # - gemm2_dequant_scale: w2_weight_scale * w2_input_scale + # - gemm1_input_dequant_scale: w1_input_scale + + # Compute combined scales + gemm1_dequant = (w1_weight_scale * w1_input_scale).contiguous().squeeze() + gemm2_act_quant = (1.0 / w2_input_scale).contiguous().to(torch.float32) + gemm2_dequant = (w2_weight_scale * w2_input_scale).contiguous().squeeze() + gemm1_input_dequant = w1_input_scale.contiguous() + + assert gemm1_dequant.ndim == 1, "gemm1_dequant must be 1D" + assert gemm2_dequant.ndim == 1, "gemm2_dequant must be 1D" + quant_scales = [gemm1_dequant, gemm2_act_quant, gemm2_dequant, gemm1_input_dequant] + + # Ensure contiguous tensors + selected_experts = selected_experts.int().contiguous() + routing_weights = routing_weights.contiguous() + + # Todo: refactor this repeating code block + + # Determine activation type + mlp_style = mlp_style.lower() + act_fn = act_fn.lower() + + activation_type = ActivationType.Swiglu + if mlp_style == "gated_mlp": + # Gated MLP uses Silu: silu(x @ w1.T) * (x @ w3.T) + # For gated MLP, concatenate w1 and w3 + # TensorRT-LLM expects [w3, w1] concatenated + w3_w1_stacked = torch.cat([w3_weight, w1_weight], dim=1).contiguous() # [E, 2*I, H] + fc1_expert_weights = w3_w1_stacked + if act_fn == "silu": + # activation_type = ActivationType.Silu + activation_type = ActivationType.Swiglu # need to fix this in trtllm + else: + raise ValueError(f"Unsupported activation '{act_fn}' for gated_mlp. Use 'silu'.") + elif mlp_style == "mlp": + # For non-gated MLP with ReLU^2 + fc1_expert_weights = w1_weight.contiguous() + if act_fn == "relu2": + activation_type = ActivationType.Relu2 + else: + raise ValueError(f"Unsupported activation '{act_fn}' for mlp. Use 'relu2'.") + else: + raise ValueError(f"Unknown mlp_style '{mlp_style}'. Use 'gated_mlp' or 'mlp'.") + + # Note! 
Outputting Float8_e4m3fn directly is not currently supported + output = torch.ops.trtllm.fused_moe( + x_q_fp8, + selected_experts, + routing_weights, + fc1_expert_weights=fc1_expert_weights, + fc1_expert_biases=None, + fc2_expert_weights=w2_weight.contiguous(), + fc2_expert_biases=None, + output_dtype=x.dtype, + quant_scales=quant_scales, + activation_type=activation_type, + ) + + # Restore original shape + return output[0].view(x_shape) + + +@trtllm_quant_fp8_moe_fused.register_fake +def trtllm_quant_fp8_moe_fused_fake( + x: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor, + w1_weight: torch.Tensor, + w2_weight: torch.Tensor, + w3_weight: torch.Tensor, + w1_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + w3_input_scale: torch.Tensor, + w1_weight_scale: torch.Tensor, + w2_weight_scale: torch.Tensor, + w3_weight_scale: torch.Tensor, + mlp_style: str, + act_fn: str, ) -> torch.Tensor: return torch.empty_like(x) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index 014f8cc7e6b..f47bb0bdcb3 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -61,6 +61,7 @@ def cuda_causal_conv_prepare_metadata( pages_per_seq: torch.Tensor, slot_idx: torch.Tensor, page_size: int, + chunk_size: int, ) -> List[torch.Tensor]: """Prepare metadata for cached causal conv (CUDA backend). @@ -81,7 +82,7 @@ def cuda_causal_conv_prepare_metadata( @cuda_causal_conv_prepare_metadata.register_fake def cuda_causal_conv_prepare_metadata_fake( - position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size ): seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) @@ -112,6 +113,7 @@ def _cuda_cached_causal_conv1d( dilation: int, groups: int, padding_mode: str, + activation: Optional[str], ) -> torch.Tensor: """Flattened cached causal conv that respects slot-indexed state caches (CUDA backend). 
@@ -175,26 +177,26 @@ def _cuda_cached_causal_conv1d( cache_indices=cache_indices, has_initial_state=has_initial_state, conv_states=conv_state_cache, - activation=None, + activation=activation, pad_slot_id=PAD_SLOT_ID, ) # (dim, total_prefill_tokens) # Scatter outputs back to y y_prefill = y_varlen.transpose(0, 1) # [total_prefill_tokens, C_out] - y_flat[:total_prefill_tokens].copy_(y_prefill.to(y_flat.dtype)) + y_flat[:total_prefill_tokens].copy_(y_prefill) # DECODE: batch update for single-token sequences if num_decode > 0: - # Use true start offsets for decode tokens (tail after prefills) - decode_idx = seq_start[num_prefill:].to(torch.long) - x_decode = inp_flat.index_select(0, decode_idx) # [num_decode, C_in] + x_decode = inp_flat[ + total_prefill_tokens : total_prefill_tokens + num_decode + ] # [num_decode, C_in] y_dec = causal_conv1d_update( x_decode, # [batch, dim] conv_state_cache, w2d, bias, - activation=None, + activation=activation, cache_seqlens=None, conv_state_indices=slot_idx[num_prefill:].to(torch.int32), pad_slot_id=PAD_SLOT_ID, @@ -202,10 +204,10 @@ def _cuda_cached_causal_conv1d( if y_dec.dim() == 3: y_dec = y_dec.squeeze(-1) - y_flat.index_copy_(0, decode_idx, y_dec.to(y_flat.dtype)) + y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec) # Custom op must not return an alias of any input; return a fresh tensor - return y.contiguous().clone() + return y @_cuda_cached_causal_conv1d.register_fake @@ -227,6 +229,7 @@ def _cuda_cached_causal_conv1d_fake( dilation: int, groups: int, padding_mode: str, + activation: Optional[str], ): return torch.empty( input.shape[0], input.shape[1], weight.shape[0], device=input.device, dtype=input.dtype @@ -293,4 +296,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: stride, padding, dilation, groups, padding_mode = extract_op_args( source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" ) - return [stride, padding, dilation, groups, padding_mode] + # None is for activation parameter, which may not exist in the source node (added by fusion later) + return [stride, padding, dilation, groups, padding_mode, None] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py index a204c559f00..6f0059d250d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py @@ -355,4 +355,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: stride, padding, dilation, groups, padding_mode = extract_op_args( source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode" ) - return [stride, padding, dilation, groups, padding_mode] + # None is for activation parameter, which may not exist in the source node (added by fusion later) + return [stride, padding, dilation, groups, padding_mode, None] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py index ccd24e7ec00..79c68c2aac7 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py @@ -120,6 +120,7 @@ def _torch_ssm_prepare_metadata( pages_per_seq: torch.Tensor, slot_idx: torch.Tensor, page_size: int, + chunk_size: int, ) -> List[torch.Tensor]: """Prepare metadata for cached SSM 
transform. @@ -143,7 +144,7 @@ def _torch_ssm_prepare_metadata( @_torch_ssm_prepare_metadata.register_fake def _torch_ssm_prepare_metadata_fake( - position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size ): # Use the same sanitization logic to determine sizes in fake mode seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) @@ -347,6 +348,9 @@ def get_cache_initializers( # Fallback: assume last dim is n_groups * state_size and choose a minimal positive size ssm_state_size = max(1, B_fake.shape[-1]) + # extract ssm_state_dtype from cache_config or hs_fake + ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype + def _get_ssm_cache(si: SequenceInfo): return torch.empty( si.max_batch_size, @@ -354,7 +358,7 @@ def _get_ssm_cache(si: SequenceInfo): head_dim, ssm_state_size, device=si.device, - dtype=cache_config.dtype or hs_fake.dtype, + dtype=ssm_state_dtype, ) return {"ssm_state_cache": _get_ssm_cache} diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 64b62419162..fb97eecee51 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -24,6 +24,94 @@ ) +@torch.library.custom_op("auto_deploy::triton_ssm_prepare_metadata", mutates_args=()) +def _triton_ssm_prepare_metadata( + position_ids: torch.Tensor, + seq_len: torch.Tensor, + input_pos: torch.Tensor, + cache_loc: torch.Tensor, + pages_per_seq: torch.Tensor, + slot_idx: torch.Tensor, + page_size: int, + chunk_size: int, +) -> List[torch.Tensor]: + """Prepare metadata for cached SSM transform. + + Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). + """ + # Determine number of active sequences and compute seq_start boundaries + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) + num_seq = len(seq_len_sanitized) + + seq_start = torch.zeros_like(seq_len_sanitized) + if num_seq > 1: + seq_start[1:] = torch.cumsum(seq_len_sanitized[:-1], 0) + + # Truncate slot indices to match active sequences + slot_idx_sanitized = slot_idx[:num_seq].clone().to(torch.long) + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/8170): update torch + # reference implementation to support chunked prefill. + use_initial_states = input_pos > 0 + + device = position_ids.device + + chunk_indices = torch.zeros(num_seq, dtype=torch.int32, device=device) + chunk_offsets = torch.zeros(num_seq, dtype=torch.int32, device=device) + cu_seqlens = torch.zeros(num_seq + 1, dtype=torch.int32, device=device) + _, s = position_ids.shape[:2] + if s > 1: + # only compute chunk indices and offsets for prefill. 
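+        # prefill sequences (seq_len > 1) are assumed to be ordered before decode sequences,
+        # so cu_seqlens and the chunk indices/offsets are built over the first num_prefill entries only.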
+ prefill_mask = seq_len_sanitized > 1 + num_prefill = int(prefill_mask.sum().item()) + num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item()) + num_decode = num_seq - num_prefill + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + torch.cumsum(seq_len_sanitized[:num_prefill].to(torch.int32), dim=0), + ], + dim=0, + ) + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size) + else: + num_prefill = 0 + num_prefill_tokens = 0 + num_decode = num_seq + batch_info_tensor = torch.tensor( + [num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32 + ) # host tensor + + return ( + seq_len_sanitized, + seq_start, + slot_idx_sanitized, + use_initial_states, + cu_seqlens, + chunk_indices, + chunk_offsets, + batch_info_tensor, + ) + + +@_triton_ssm_prepare_metadata.register_fake +def _triton_ssm_prepare_metadata_fake( + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size +): + # Use the same sanitization logic to determine sizes in fake mode + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) + num_seq = len(seq_len_sanitized) + return ( + torch.empty_like(seq_len_sanitized), + torch.empty_like(seq_len_sanitized), + torch.empty(num_seq, dtype=torch.long, device=slot_idx.device), + torch.empty(num_seq, dtype=torch.bool, device=slot_idx.device), + torch.empty(num_seq + 1, dtype=torch.int32, device=slot_idx.device), # cu seqlens + torch.empty(num_seq, dtype=torch.int32, device=slot_idx.device), # chunk indices + torch.empty(num_seq, dtype=torch.int32, device=slot_idx.device), # chunk offsets + torch.empty(2, dtype=torch.int32), # batch info tensor + ) + + @torch.library.custom_op("auto_deploy::triton_cached_ssm", mutates_args={}) def _triton_cached_ssm( # INPUTS (dense but may be flattened across sequences) @@ -39,6 +127,10 @@ def _triton_cached_ssm( seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] use_initial_states: torch.Tensor, # [num_seq] + cu_seqlens: torch.Tensor, # [num_seq + 1] + chunk_indices: torch.Tensor, # [num_seq + 1] + chunk_offsets: torch.Tensor, # [num_seq + 1] + batch_info_tensor: torch.Tensor, # [2] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] # CONSTANTS @@ -51,8 +143,7 @@ def _triton_cached_ssm( - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot. - Decode: batch single-token updates with selective_state_update and update states per slot. 
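+    The host-side batch_info_tensor produced by triton_ssm_prepare_metadata is laid out as
+    [num_prefill, num_prefill_tokens, num_decode] and is used below to split prefill from decode tokens.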
""" - b, s = hidden_states.shape[:2] - num_seq = seq_len.shape[0] + b, s, num_heads, head_dim = hidden_states.shape # Flatten tokens for indexing/scatter bs = b * s device = hidden_states.device @@ -64,39 +155,23 @@ def _triton_cached_ssm( y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format) y_flat = y.view(bs, *y.shape[2:]) - num_heads = hidden_states.shape[2] - head_dim = hidden_states.shape[3] ssm_state_size = B.shape[3] - if s == 1: - num_prefill = 0 - num_decode = num_seq - else: - prefill_mask = seq_len > 1 - num_prefill = int(prefill_mask.sum().item()) - num_decode = num_seq - num_prefill + [num_prefill, num_prefill_tokens, num_decode] = batch_info_tensor.tolist() # Prefill: concatenate tokens at the front and run combined scan if num_prefill > 0: - seq_len_prefill = seq_len[:num_prefill].to(torch.int32) - total_prefill_tokens = int(seq_len_prefill.sum().item()) + seq_len_prefill = seq_len[:num_prefill] - hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] - B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] - C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] - dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H] + hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D] + B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] + C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N] + dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H] - cu_seqlens = torch.cat( - [ - torch.zeros(1, dtype=torch.int32, device=device), - torch.cumsum(seq_len_prefill, dim=0), - ], - dim=0, - ) seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32) seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1) - initial_states = chunk_indices = chunk_offsets = None + initial_states = None if torch.any(use_initial_states[:num_prefill]): initial_states = torch.where( use_initial_states[:num_prefill, None, None, None], @@ -106,6 +181,11 @@ def _triton_cached_ssm( chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( cu_seqlens, chunk_size ) + + else: + chunk_indices = None + chunk_offsets = None + y_prefill, varlen_states = mamba_chunk_scan_combined( hs_prefill, dt_prefill, @@ -125,46 +205,43 @@ def _triton_cached_ssm( dt_limit=(time_step_limit[0], time_step_limit[1]), return_final_states=False, return_varlen_states=True, + mamba_ssm_cache_dtype=ssm_state_cache.dtype, ) - y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype) + y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype) ssm_state_cache.index_copy_( - 0, slot_idx[:num_prefill].to(torch.long), varlen_states.to(ssm_state_cache.dtype) + 0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype) ) # Decode: batch single-token updates via selective_state_update if num_decode > 0: - total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item()) - slot_idx_decode = slot_idx[num_prefill:].to(torch.long) + slot_idx_decode = slot_idx[num_prefill:] - x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H, D] - B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N] - C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N] - dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H] + x_decode = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H, D] + B_decode 
= B_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N] + C_decode = C_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N] + dt_decode = dt_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H] dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim) dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim) - dt_pre = torch.nn.functional.softplus(dt_hp + dt_bias_hp.to(dtype=dt_hp.dtype)) - dt_pre = torch.clamp(dt_pre, time_step_limit[0], time_step_limit[1]) A_full = A[..., None, None].expand(num_heads, head_dim, ssm_state_size) D_full = D[..., None].expand(num_heads, head_dim) - dt_bias_zero = torch.zeros_like(dt_bias_hp) y_dec = selective_state_update( ssm_state_cache, x_decode, - dt_pre, + dt_hp, A_full, B_decode, C_decode, D=D_full, z=None, - dt_bias=dt_bias_zero, - dt_softplus=False, + dt_bias=dt_bias_hp, + dt_softplus=True, state_batch_indices=slot_idx_decode, ) # [nd, H, D] - y_flat[total_prefill_tokens : total_prefill_tokens + num_decode] = y_dec.to(y_flat.dtype) + y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype)) return y @@ -184,6 +261,10 @@ def _triton_cached_ssm_fake( seq_start: torch.Tensor, # [num_seq] slot_idx: torch.Tensor, # [num_seq] use_initial_states: torch.Tensor, # [num_seq] + cu_seqlens: torch.Tensor, # [num_seq + 1] + chunk_indices: torch.Tensor, # [num_seq + 1] + chunk_offsets: torch.Tensor, # [num_seq + 1] + batch_info_tensor: torch.Tensor, # [2] # CACHES ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size] # CONSTANTS @@ -198,9 +279,7 @@ def _triton_cached_ssm_fake( ) -## Note: we reuse the existing metadata custom op and its registered fake from torch backend. - - +# TODO: consider inheriting from TorchBackendSSM instead of redefining everything @AttentionRegistry.register("triton_ssm") class TritonBackendSSM(AttentionDescriptor): @classmethod @@ -228,8 +307,9 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: - # Returns (seq_len, seq_start, slot_idx, use_initial_states) - return torch.ops.auto_deploy.torch_ssm_prepare_metadata, 4 + # Returns: seq_len, seq_start, slot_idx, use_initial_states, + # cu_seqlens, chunk_indices, chunk_offsets, batch_info_tensor + return torch.ops.auto_deploy.triton_ssm_prepare_metadata, 8 @classmethod def get_cache_initializers( @@ -247,6 +327,9 @@ def get_cache_initializers( else: ssm_state_size = max(1, B_fake.shape[-1]) + # extract ssm_state_dtype from cache_config or hs_fake + ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype + def _get_ssm_cache(si: SequenceInfo): return torch.empty( si.max_batch_size, @@ -254,7 +337,7 @@ def _get_ssm_cache(si: SequenceInfo): head_dim, ssm_state_size, device=si.device, - dtype=cache_config.dtype or hs_fake.dtype, + dtype=ssm_state_dtype, ) return {"ssm_state_cache": _get_ssm_cache} diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index ea68da9e508..2a0748783f4 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -182,6 +182,7 @@ def prepare_fused_mla_metadata( pages_per_seq: torch.Tensor, slot_idx: torch.Tensor, page_size: int, + chunk_size: int, ) -> List[torch.Tensor]: num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) @@ -196,7 +197,7 @@ def 
prepare_fused_mla_metadata( @prepare_fused_mla_metadata.register_fake def prepare_fused_mla_metadata_fake( - position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size ): return ( torch.empty_like(seq_len), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index df2d4b24c59..ddfd093d5c2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -363,6 +363,7 @@ def torch_backend_prepare_metadata( pages_per_seq: torch.Tensor, slot_idx: torch.Tensor, page_size: int, + chunk_size: int, ) -> List[torch.Tensor]: """Prepare metadata for torch backend attention (similar to triton backend).""" num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 34e0c5a988d..1ca4a605840 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -291,6 +291,7 @@ def prepare_fused_mha_metadata( pages_per_seq: torch.Tensor, slot_idx: torch.Tensor, page_size: int, + chunk_size: int, ) -> List[torch.Tensor]: # TODO: maybe use slot_idx instead of pages_per_seq?? num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) @@ -308,7 +309,7 @@ def prepare_fused_mha_metadata( # SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_fused_mha_metadata.register_fake def prepare_fused_mha_metadata_fake( - position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size ): num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) return ( diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index b5fb106e10b..71b4b8b2c5c 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -194,6 +194,11 @@ def get_sharding_config(self) -> Dict: """Returns the sharding config for this model.""" return self._sharding_config + @property + def chunk_size(self) -> Optional[int]: + """Returns the chunk size for this model.""" + return None + def get_cache_config(self) -> CacheConfig: """Return the cache configuration for the model. 
diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index 41b5c90214c..6f48cb511c6 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -124,6 +124,12 @@ def vocab_size_padded(self) -> Optional[int]: model_config, _ = self._get_model_config() return getattr(model_config, "vocab_size", None) + @property + def chunk_size(self) -> Optional[int]: + """Returns the chunk size for this model.""" + model_config, _ = self._get_model_config() + return getattr(model_config, "chunk_size", None) + def _recursive_update_config( self, config: PretrainedConfig, update_dict: Dict[str, Any] ) -> Tuple[PretrainedConfig, Dict[str, Any]]: diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 0b6ba4921b7..f90b7c4c8be 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -121,8 +121,8 @@ def build_from_config(cls, ad_config: LlmArgs): page_size=attn_page_size, max_num_tokens=max_num_tokens, vocab_size_padded=factory.vocab_size_padded, + chunk_size=factory.chunk_size, ) - # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. @@ -167,7 +167,7 @@ def __init__( # build model self.model = get_inference_model(self.cache_seq_interface) - + print(self.llm_args) # start fresh with fixed seed torch.manual_seed(42) @@ -324,7 +324,6 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer torch.cuda.set_device(rank) port = mpi_dist.broadcast(dist.get_free_port()) # use MPI broadcast to pick a free port dist.initialize_or_skip(rank, world_size, port) - # some config assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py new file mode 100644 index 00000000000..3acc8e1f80f --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_causal_conv.py @@ -0,0 +1,120 @@ +"""Fusion transform for fusing activation functions into causal_conv1d operations.""" + +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + + +def _match_causal_conv_activation_pattern( + graph: GraphModule, + target_op, +) -> List[Tuple[Node, Node, str]]: + """ + Match the causal_conv + activation pattern in the graph. + + The pattern corresponds to: + conv_out = cuda_cached_causal_conv1d(...) 
+ out = activation(conv_out) + + Args: + graph: The graph module to search + target_op: The target causal conv op to match + + Returns: + A list of tuples (conv_node, activation_node, activation_name) for each match + """ + matches = [] + + for node in graph.nodes: + if not is_op(node, target_op): + continue + + # Check if this node has exactly one user and it's an activation + if len(node.users) != 1: + continue + + activation_node = list(node.users.keys())[0] + if activation_node.op != "call_function": + continue + + # Detect activation type + activation_name: Optional[str] = None + if activation_node.target in (torch.ops.aten.silu.default, F.silu): + activation_name = "silu" + # Can extend to support more activations here: + # elif activation_node.target in (torch.ops.aten.gelu.default, F.gelu): + # activation_name = "gelu" + + if activation_name is not None: + matches.append((node, activation_node, activation_name)) + + return matches + + +@TransformRegistry.register("fuse_causal_conv_activation") +class FuseCausalConvActivation(BaseTransform): + """Fuses activation functions into cached CUDA causal_conv1d operations. + + This transform detects patterns like: + conv_out = cuda_cached_causal_conv1d(...) + out = silu(conv_out) + + And replaces them with: + out = cuda_cached_causal_conv1d(..., activation="silu") + + This optimization allows the backend CUDA kernels to fuse the activation, + reducing memory bandwidth and improving performance. + + Note: This runs AFTER insert_cached_causal_conv, so it operates on the + cached CUDA operations, not the uncached torch operations. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + + # Step 1: Identify causal_conv + activation pattern + matches = _match_causal_conv_activation_pattern( + graph, + target_op=torch.ops.auto_deploy.cuda_cached_causal_conv1d, + ) + + # Step 2: Replace matched patterns with fused version + for conv_node, activation_node, activation_name in matches: + with graph.inserting_after(conv_node): + # Create new call with fused activation + # Replace the last arg (activation=None) with activation_name + new_args = list(conv_node.args[:-1]) + [activation_name] + fused_node = graph.call_function( + torch.ops.auto_deploy.cuda_cached_causal_conv1d, + args=tuple(new_args), + ) + + # Replace all uses of activation_node with fused_node + activation_node.replace_all_uses_with(fused_node) + + # Remove the old nodes + graph.erase_node(activation_node) + graph.erase_node(conv_node) + + gm.recompile() + + info = TransformInfo( + skipped=False, + num_matches=len(matches), + is_clean=False, + has_valid_shapes=False, + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 055c93744e9..7594af22619 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -19,7 +19,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int: if not is_op(node, torch.ops.auto_deploy.torch_moe): continue - (mlp_style_val,) = extract_op_args(node, "mlp_style") + (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn") hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = ( extract_op_args( @@ -50,7 +50,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int: fused_w_up_experts = 
torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0) new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}" # Triton fused MoE op supports mlp only. - replacement_op = torch.ops.auto_deploy.triton_moe_fused + replacement_op = torch.ops.auto_deploy.trtllm_moe_fused else: raise ValueError(f"Unknown mlp_style: {mlp_style_val}") @@ -75,6 +75,10 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int: graph.get_attr(new_key_w_up), graph.get_attr(new_key_w_down), ), + kwargs={ + "mlp_style": mlp_style_val, + "act_fn": act_fn_val, + }, ) node.replace_all_uses_with(new_node) @@ -706,7 +710,7 @@ def get_param_or_buffer(target): # Create new node with get_attr for stacked parameters with graph.inserting_before(node): new_node = graph.call_function( - torch.ops.auto_deploy.triton_quant_fp8_moe, + torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused, args=( hidden_states, selected_experts, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 5a2b8485d6b..ecf42d0b238 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -8,7 +8,12 @@ from pydantic import Field from torch.fx import GraphModule, Node -from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegistry, Constant +from ...custom_ops.attention_interface import ( + AttentionDescriptor, + AttentionRegistry, + CacheConfig, + Constant, +) from ...distributed.common import all_gather_object, get_world_size from ...distributed.common import is_initialized as is_distributed_initialized from ...models.factory import ModelFactory @@ -66,6 +71,9 @@ class InsertCachedAttentionConfig(TransformConfig): """Configuration for the insert cached attention transform.""" backend: Optional[str] = Field(default=None, description="The attention backend to use.") + cache_config: CacheConfig = Field( + default_factory=CacheConfig, description="The custom cache configuration to use." 
+ ) @TransformRegistry.register("insert_cached_attention") @@ -137,7 +145,9 @@ def _apply( """Replace uncached source attention node with corresponding cached attn node.""" attn_descriptor = self.attn_descriptor - cache_config = factory.get_cache_config() + # run field-wise or to combine the cache config from the transform and the factory + # the transform config takes precedence over the factory config + cache_config = self.config.cache_config | factory.get_cache_config() # Get all attention nodes and their info objects source_op = attn_descriptor.get_source_attention_op() diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 6c1a13bec8e..ece625a8854 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -24,6 +24,21 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: torch.bmm(a, b, out=out) +from enum import IntEnum + + +class ActivationType(IntEnum): + Gelu = 0 + Relu = 1 + Silu = 2 + Swiglu = 3 + Geglu = 4 + SwigluBias = 5 + Relu2 = 6 + Identity = 7 + InvalidType = 8 + + class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict = dict() @@ -52,6 +67,7 @@ def __init__( use_mxfp8_act_scaling: bool, min_latency_mode: bool, use_fused_finalize: bool, + activation_type: ActivationType, unpadded_hidden_size: Optional[int] = None, ): self.x_dtype = x_dtype @@ -72,6 +88,7 @@ def __init__( self.use_mxfp8_act_scaling = use_mxfp8_act_scaling self.min_latency_mode = min_latency_mode self.use_fused_finalize = use_fused_finalize + self.activation_type = activation_type self.unpadded_hidden_size = unpadded_hidden_size if unpadded_hidden_size is not None else 0 instance_key = (x_dtype, weight_dtype, output_dtype, @@ -117,6 +134,7 @@ def forward( gemm_idx, tactic, do_preparation, + self.activation_type, self.unpadded_hidden_size, ) @@ -153,12 +171,13 @@ def fused_moe( tune_max_num_tokens: int = 8192, tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, + activation_type: int = ActivationType.Swiglu, unpadded_hidden_size: Optional[int] = None, out_tensor: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: tuner = AutoTuner.get() - + activation_type = ActivationType(activation_type) # Only the non-alltoall case is considered for profiling in the warmup phase. # Therefore, to get the correct tactics during the actual inference, the inputs to the tuner should be the same as when not using alltoall. 
if enable_alltoall: @@ -189,6 +208,7 @@ def fused_moe( use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, use_fused_finalize=use_fused_finalize, + activation_type=activation_type, unpadded_hidden_size=unpadded_hidden_size, ) @@ -223,8 +243,8 @@ def fused_moe( swizzled_input_sf, swiglu_alpha, swiglu_beta, swiglu_limit, tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank, enable_alltoall, min_latency_mode, - [gemm_tactic_1, gemm_tactic_2], unpadded_hidden_size, - tuner_num_tokens, out_tensor) + [gemm_tactic_1, gemm_tactic_2], activation_type, + unpadded_hidden_size, tuner_num_tokens, out_tensor) return output if min_latency_mode else [output] @@ -260,6 +280,7 @@ def _(input: torch.Tensor, tune_max_num_tokens: int = 8192, tuner_num_tokens: Optional[int] = None, tuner_top_k: Optional[int] = None, + activation_type: ActivationType = ActivationType.Swiglu, unpadded_hidden_size: Optional[int] = None, out_tensor: Optional[torch.Tensor] = None): seq_len = input.shape[0] diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index 7534a0d22ad..56ae32d34ea 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -1,3 +1,4 @@ +import json import uuid from functools import partial from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal, @@ -185,6 +186,36 @@ def parse_chat_message_content( content, mm_data_tracker, ) + if role == "assistant": + result.update(_parse_assistant_message_content(message)) + elif role == "tool": + result.update(_parse_tool_message_content(message)) + return result + + +# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406 +def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]: + result = {} + tool_calls = message.get("tool_calls") + if tool_calls is not None: + result["tool_calls"] = [] + for item in tool_calls: + if content := item["function"].get("arguments"): + if isinstance(content, str): + item["function"]["arguments"] = json.loads(content) + else: + item["function"]["arguments"] = content + else: + item["function"]["arguments"] = {} + result["tool_calls"].append(item) + + return result + + +def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]: + result = {} + if "tool_call_id" in message: + result["tool_call_id"] = message["tool_call_id"] return result diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index af8111d1f07..80695c7366f 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -396,6 +396,12 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False): """Enables custom roles in the Chat Completion API.""" + + # This is so custom fields not in any of the `ChatCompletionMessageParam` defined by OpenAI + # are still allowed. + # Examples include: assistant messages with `reasoning` / `reasoning_content`. 
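+    # For illustration (a hedged sketch, not an exhaustive list): with `extra="allow"`,
+    # a payload such as
+    #     {"role": "assistant", "content": "hi", "reasoning_content": "..."}
+    # validates instead of being rejected for the unknown key.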
+    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore
+
     role: Required[str]
     """The role of the message's author."""

diff --git a/tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py b/tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py
new file mode 100644
index 00000000000..7b3648547b5
--- /dev/null
+++ b/tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py
@@ -0,0 +1,344 @@
+# Adapted from: https://raw.githubusercontent.com/sgl-project/sglang/d8fcbaa38da95201914a1277971044ee66837b26/python/sglang/srt/function_call/qwen3_coder_detector.py
+
+import ast
+import html
+import json
+import re
+from typing import Any, Dict, List, Tuple
+
+from tensorrt_llm.logger import logger
+from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam as Tool
+from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
+from tensorrt_llm.serve.tool_parser.core_types import (
+    StreamingParseResult,
+    ToolCallItem,
+    _GetInfoFunc,
+)
+
+
+def _safe_val(raw: str) -> Any:
+    raw = html.unescape(raw.strip())
+    try:
+        return json.loads(raw)
+    except Exception:
+        try:
+            return ast.literal_eval(raw)
+        except Exception:
+            return raw
+
+
+class Qwen3CoderToolParser(BaseToolParser):
+    """Tool parser for Qwen 3 models.
+
+    Assumes function call format:
+
+    <tool_call>
+    <function=execute_bash>
+    <parameter=command>
+    pwd && ls
+    </parameter>
+    </function>
+    </tool_call>
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_start_token: str = "<tool_call>"
+        self.tool_call_end_token: str = "</tool_call>"
+        self.tool_call_prefix: str = "<function="
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
+        )
+        self.tool_call_function_regex = re.compile(
+            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
+        )
+        self.tool_call_parameter_regex = re.compile(
+            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
+        )
+        self._buf: str = ""
+
+        # Streaming state, reset per tool call in _reset_streaming_state().
+        self._in_tool_call: bool = False
+        self._function_name_sent: bool = False
+        self._current_function_name: str = ""
+        self._current_parameters: Dict[str, Any] = {}
+        self._streamed_parameters: Dict[str, Any] = {}
+
+    def has_tool_call(self, text: str) -> bool:
+        return self.tool_call_start_token in text
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        normal, calls = self._extract(text, tools)
+        return StreamingParseResult(normal_text=normal, calls=calls)
+
+    def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
+        self._buf += new_text
+        normal = ""
+        calls: List[ToolCallItem] = []
+
+        # Build tool indices for validation
+        if not hasattr(self, "_tool_indices"):
+            self._tool_indices = self._get_tool_indices(tools)
+
+        while True:
+            # If we're not in a tool call and don't see a start token, return normal text
+            if not self._in_tool_call and self.tool_call_start_token not in self._buf:
+                normal += self._buf
+                self._buf = ""
+                break
+
+            # Look for tool call start
+            if not self._in_tool_call:
+                s = self._buf.find(self.tool_call_start_token)
+                if s == -1:
+                    normal += self._buf
+                    self._buf = ""
+                    break
+
+                normal += self._buf[:s]
+                self._buf = self._buf[s:]
+
+                self._in_tool_call = True
+                self._function_name_sent = False
+                self._current_function_name = ""
+                self._current_parameters = {}
+                self._streamed_parameters = {}
+
+                # Remove the start token
+                self._buf = self._buf[len(self.tool_call_start_token) :]
+                continue
+
+            # We're in a tool call, try to parse function name if not sent yet
+            if not self._function_name_sent:
+                # Look for function name pattern: <function=...>
+                function_match = re.search(r"<function=([^>]+)>", self._buf)
+                if function_match:
+                    function_name = function_match.group(1).strip()
+
+                    # Validate function name
+                    if function_name in self._tool_indices:
+                        self._current_function_name = function_name
+                        self._function_name_sent = True
+
+                        # Initialize tool call tracking
+                        if self.current_tool_id == -1:
+                            self.current_tool_id = 0
+
+                        # Ensure tracking arrays are large enough
+                        while len(self.prev_tool_call_arr) <= self.current_tool_id:
+                            self.prev_tool_call_arr.append({})
+                        while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("") + + # Store tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": function_name, + "arguments": {}, + } + + # Send tool name with empty parameters + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ) + + # Remove the processed function declaration + self._buf = self._buf[function_match.end() :] + continue + else: + # Invalid function name, reset state + logger.warning(f"Invalid function name: {function_name}") + self._reset_streaming_state() + normal += self._buf + self._buf = "" + break + else: + # Function name not complete yet, wait for more text + break + + # Parse parameters incrementally + if self._function_name_sent: + # Process parameters and get any calls to emit + parameter_calls = self._parse_and_stream_parameters(self._buf) + calls.extend(parameter_calls) + + # Check if tool call is complete + if self.tool_call_end_token in self._buf: + end_pos = self._buf.find(self.tool_call_end_token) + + # Add closing brace to complete the JSON object + current_streamed = self.streamed_args_for_tool[self.current_tool_id] + if current_streamed: + # Count opening and closing braces to check if JSON is complete + open_braces = current_streamed.count("{") + close_braces = current_streamed.count("}") + if open_braces > close_braces: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters="}", + ) + ) + self.streamed_args_for_tool[self.current_tool_id] = ( + current_streamed + "}" + ) + + # Complete the tool call + self._buf = self._buf[end_pos + len(self.tool_call_end_token) :] + self._reset_streaming_state() + self.current_tool_id += 1 + continue + else: + # Tool call not complete yet, wait for more text + break + + return StreamingParseResult(normal_text=normal, calls=calls) + + def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]: + """Parse complete parameter blocks from text and return any tool call items to emit. + + This method: + 1. Finds all complete blocks + 2. Parses them into a dictionary + 3. Compares with current parameters and generates diff if needed + 4. 
Updates internal state
+
+        Args:
+            text_to_parse: The text to search for parameter blocks
+
+        Returns:
+            List of ToolCallItem objects to emit (may be empty)
+        """
+        calls: List[ToolCallItem] = []
+
+        # Find all complete parameter patterns
+        param_matches = list(
+            re.finditer(r"<parameter=([^>]+)>(.*?)</parameter>", text_to_parse, re.DOTALL)
+        )
+
+        # Build new parameters dictionary
+        new_params = {}
+        for match in param_matches:
+            param_name = match.group(1).strip()
+            param_value = match.group(2)
+            new_params[param_name] = _safe_val(param_value)
+
+        # Calculate parameter diff to stream with proper incremental JSON building
+        if new_params != self._current_parameters:
+            previous_args_json = self.streamed_args_for_tool[self.current_tool_id]
+
+            # Build incremental JSON properly
+            if not self._current_parameters:
+                # First parameter(s) - start JSON object but don't close it yet
+                items = []
+                for key, value in new_params.items():
+                    items.append(
+                        f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                    )
+                json_fragment = "{" + ", ".join(items)
+
+                calls.append(
+                    ToolCallItem(
+                        tool_index=self.current_tool_id,
+                        name=None,
+                        parameters=json_fragment,
+                    )
+                )
+                self.streamed_args_for_tool[self.current_tool_id] = json_fragment
+
+            else:
+                # Additional parameters - add them incrementally
+                new_keys = set(new_params.keys()) - set(self._current_parameters.keys())
+                if new_keys:
+                    # Build the continuation part (no closing brace yet)
+                    continuation_parts = []
+                    for key in new_keys:
+                        value = new_params[key]
+                        continuation_parts.append(
+                            f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
+                        )
+
+                    json_fragment = ", " + ", ".join(continuation_parts)
+
+                    calls.append(
+                        ToolCallItem(
+                            tool_index=self.current_tool_id,
+                            name=None,
+                            parameters=json_fragment,
+                        )
+                    )
+                    self.streamed_args_for_tool[self.current_tool_id] = (
+                        previous_args_json + json_fragment
+                    )
+
+        # Update current state
+        self._current_parameters = new_params
+        self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params
+
+        return calls
+
+    def _reset_streaming_state(self):
+        """Reset streaming state for the next tool call."""
+        self._in_tool_call = False
+        self._function_name_sent = False
+        self._current_function_name = ""
+        self._current_parameters = {}
+        self._streamed_parameters = {}
+        self.current_tool_name_sent = False
+
+    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
+        normal_parts: List[str] = []
+        calls: List[ToolCallItem] = []
+        cursor = 0
+        while True:
+            s = text.find(self.tool_call_start_token, cursor)
+            if s == -1:
+                normal_parts.append(text[cursor:])
+                break
+            normal_parts.append(text[cursor:s])
+            e = text.find(self.tool_call_end_token, s)
+            if e == -1:
+                normal_parts.append(text[s:])
+                break
+            block = text[s : e + len(self.tool_call_end_token)]
+            cursor = e + len(self.tool_call_end_token)
+            calls.extend(self._parse_block(block, tools))
+        return "".join(normal_parts), calls
+
+    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
+        res: List[ToolCallItem] = []
+        for m in self.tool_call_function_regex.findall(block):
+            txt = m[0] if m[0] else m[1]
+            if ">" not in txt:
+                continue
+            idx = txt.index(">")
+            fname = txt[:idx].strip()
+            body = txt[idx + 1 :]
+            params: Dict[str, Any] = {}
+            for pm in self.tool_call_parameter_regex.findall(body):
+                ptxt = pm[0] if pm[0] else pm[1]
+                if ">" not in ptxt:
+                    continue
+                pidx = ptxt.index(">")
+                pname = ptxt[:pidx].strip()
+                pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
params[pname] = _safe_val(pval) + raw = {"name": fname, "arguments": params} + try: + # TODO: fix idx in function call, the index for a function + # call will always be -1 in parse_base_json + res.extend(self.parse_base_json(raw, tools)) + except Exception: + logger.warning("invalid tool call for %s dropped", fname) + return res + + def supports_structural_tag(self) -> bool: + return False + + def structure_info(self) -> _GetInfoFunc: + raise NotImplementedError diff --git a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py index 73b02510a67..8a9bbe298c1 100644 --- a/tensorrt_llm/serve/tool_parser/tool_parser_factory.py +++ b/tensorrt_llm/serve/tool_parser/tool_parser_factory.py @@ -1,12 +1,14 @@ from typing import Type from .base_tool_parser import BaseToolParser +from .qwen3_coder_parser import Qwen3CoderToolParser from .qwen3_tool_parser import Qwen3ToolParser class ToolParserFactory: parsers: dict[str, Type[BaseToolParser]] = { "qwen3": Qwen3ToolParser, + "qwen3_coder": Qwen3CoderToolParser, } @staticmethod diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 90f7bd42733..3bca2c6eced 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -73,6 +73,7 @@ l0_a10: # executor - unittest/executor/test_rpc.py # trtllm-serve CPU-only + - unittest/llmapi/apps/test_chat_utils.py - unittest/llmapi/apps/test_tool_parsers.py - unittest/llmapi/apps/test_harmony_channel_validation.py - condition: diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py index 2c9e4a70720..81f76e8a669 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py @@ -82,6 +82,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env): d, g, pm, + None, ) assert y.shape == (batch, seq, c) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py new file mode 100644 index 00000000000..576fd3ce2db --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py @@ -0,0 +1,430 @@ +""" +This file contains test functions copied from: +https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_cutlass_fused_moe.py +""" + +from typing import Callable + +import pytest +import torch +from _torch_test_utils import fp8_compatible, trtllm_ops_available # noqa: F401 +from torch.nn import functional as F + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType + +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +FP8_DTYPE = torch.float8_e4m3fn + + +def _is_hopper_or_later(): + return torch.cuda.get_device_capability(0) >= (8, 9) + + +def dynamic_per_tensor_fp8_quant(x: torch.tensor) -> tuple[torch.tensor, torch.tensor]: + fp8_traits_max = FLOAT8_E4M3_MAX + fp8_traits_min = -FLOAT8_E4M3_MAX + fp8_max = torch.tensor(fp8_traits_max).float() + one = torch.tensor(1.0).float() + + x_max = x.abs().max().float() + scale = x_max / fp8_max + iscale = one / scale + out = (x.float() * iscale).clamp(fp8_traits_min, 
fp8_traits_max).to(FP8_DTYPE) + return out, scale.view((1,)) + + +def gen_tensor(shape, dtype, stype=None, scale=1.0): + x = torch.randn(*shape, dtype=dtype).cuda() * scale + return x.to(stype) if stype else x + + +def cast_to_representable(x): + """ + Convert a tensor of floats to exactly representable in FP8 format to reduce quantization error in the test. + + returns: + x_dq: A tensor of floats that is exactly representable in FP8 format. + x_dq = dq(q(x, x_scale), x_scale) + where x_scale is computed using min-max range clipping. + """ + x_q, x_scale = dynamic_per_tensor_fp8_quant(x) + x_dq = x_q.to(x.dtype) * x_scale.to(x.dtype) + return x_dq + + +def compute_routing(router_logits: torch.Tensor, top_k: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute routing weights and selected experts from router logits. + + Args: + router_logits (torch.Tensor): Router logits of shape [batch_size, num_experts] + top_k (int): Number of experts to route to per token + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - routing_weights: Expert weights of shape [batch_size, top_k] + - selected_experts: Expert indices of shape [batch_size, top_k] + """ + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.float() + return routing_weights, selected_experts + + +def compute_with_experts( + num_experts, + x, + w31_weight, + w2_weight, + selected_experts, + routing_weights, + alpha=None, + beta=None, + limit=None, + activation_func="silu", +): + def relu2(x: torch.Tensor) -> torch.Tensor: + return torch.square(F.relu(x)) + + results = torch.zeros_like(x) + for expert_id in range(num_experts): + mask = selected_experts == expert_id + if not mask.sum(): + continue + batch_idx, nth_expert = torch.where(mask) + w31_expert = w31_weight[expert_id] # [2 * intermediate_size, hidden_size] + w2_expert = w2_weight[expert_id] # [hidden_size, intermediate_size] + + # Split w13 into w1 and w3 + w3_expert, w1_expert = torch.chunk(w31_expert, 2, dim=0) + + expert_inputs = x[batch_idx] + if alpha is not None and limit is not None and beta is not None: + # SwiGLUBias + x1 = expert_inputs @ w1_expert.t() + x1 = x1.clamp_(min=None, max=limit) + x1_scaled = x1 * torch.sigmoid(alpha * x1) + x2 = expert_inputs @ w3_expert.t() + x2 = x2.clamp_(min=-limit, max=limit) + beta + + inter = x1_scaled * x2 + else: + if activation_func == "swiglu" or activation_func == "silu": + inter = F.silu(expert_inputs @ w1_expert.t()) * (expert_inputs @ w3_expert.t()) + else: + inter = relu2(expert_inputs @ w1_expert.t()) + + output = inter @ w2_expert.t() + results[batch_idx] += routing_weights[batch_idx, nth_expert, None] * output + return results.view_as(x) + + +def _get_test_data( + otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE +): + input_shape = (batch_size, hidden_size) + w31_shape = (num_experts, 2 * intermediate_size, hidden_size) + w2_shape = (num_experts, hidden_size, intermediate_size) + + x = cast_to_representable(gen_tensor(input_shape, otype, scale=X_GEN_SCALE)) + router_logits = gen_tensor((batch_size, num_experts), otype) + w31_weight = gen_tensor(w31_shape, otype, wtype) + w2_weight = gen_tensor(w2_shape, otype, wtype) + w31_empty_scales = torch.empty(num_experts, 2, dtype=otype).cuda() + w2_empty_scales = torch.empty(num_experts, 1, dtype=otype).cuda() + return x, 
router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales + + +def _activation_type_from_str(activation_func: str) -> ActivationType: + return ActivationType.Swiglu if activation_func in ["swiglu", "silu"] else ActivationType.Relu2 + + +def _print_diff_if( + condition: Callable[[torch.Tensor], bool], + diff: torch.Tensor, + ad_test_output: torch.Tensor, + ref_output: torch.Tensor, +): + if condition(diff): + print("diff: " + "-" * 20) + print(f"{diff[:10]}") + print("test_output: " + "-" * 20) + print(f"{ad_test_output[:10]}") + print("ref_output: " + "-" * 20) + print(f"{ref_output[:10]}") + + +# Test configurations +BATCH_SIZES = [ + 1, +] +HIDDEN_SIZES = [ + 128, +] +NUM_EXPERTS = [2] +TOP_K_VALUES = [2] +INTERMEDIATE_SIZES = [ + 128, +] +EP_NUM_EXPERTS = [8] +EP_TOP_K = [2] + + +F16_TEST_DTYPES = [ + (torch.float16, torch.float16, torch.float16), + (torch.bfloat16, torch.bfloat16, torch.bfloat16), +] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("itype, otype, wtype", F16_TEST_DTYPES) +@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.skipif( + not _is_hopper_or_later() or not trtllm_ops_available(), + reason="Requires Hopper or later and trtllm support", +) +def test_trtllm_fused_moe( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + itype, + otype, + wtype, + activation_func, +): + # Skip invalid configurations + if top_k > num_experts: + pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") + + torch.manual_seed(42) + if activation_func in ["swiglu", "silu"]: + X_GEN_SCALE = 1.0 + else: + X_GEN_SCALE = 0.5 + + x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data( + otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE + ) + + routing_weights, selected_experts = compute_routing(router_logits, top_k) + ref_output = compute_with_experts( + num_experts, + x, + w31_weight, + w2_weight, + selected_experts, + routing_weights, + activation_func=activation_func, + ) + + torch.cuda.synchronize() + print("before fused_moe.cutlass_fused_moe") + + assert itype == torch.bfloat16 or itype == torch.float16, ( + "F16 test only supports bfloat16 or float16" + ) + assert otype == torch.bfloat16 or otype == torch.float16, ( + "F16 test only supports bfloat16 or float16" + ) + assert wtype == torch.bfloat16 or wtype == torch.float16, ( + "F16 test only supports bfloat16 or float16" + ) + + activation_type = _activation_type_from_str(activation_func) + + def get_fc1_expert_weights( + activation_func: str, w31_weight: torch.Tensor, w1_weight: torch.Tensor + ) -> torch.Tensor: + if activation_func == "relu2": + return w1_weight.contiguous() + else: + return w31_weight + + # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) + _, w1_weight = torch.chunk(w31_weight, 2, dim=1) + + ad_test_output = torch.ops.auto_deploy.trtllm_moe_fused( + x, + selected_experts.to(torch.int), + routing_weights, + w3_w1_stacked_weight=get_fc1_expert_weights(activation_func, w31_weight, w1_weight), + w2_stacked_weight=w2_weight, + mlp_style="mlp" if activation_func == "relu2" else "gated_mlp", + act_fn=activation_func, + ) + trtllm_test_output = torch.ops.trtllm.fused_moe( + 
x, + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=w1_weight.contiguous() if activation_func == "relu2" else w31_weight, + fc1_expert_biases=None, + fc2_expert_weights=w2_weight, + fc2_expert_biases=None, + output_dtype=otype, + quant_scales=[], + activation_type=activation_type, + )[0].view(x.shape) + + torch.cuda.synchronize() + + diff = (ref_output - ad_test_output).abs() + print(f"max diff: {diff.max()}") + torch.testing.assert_close(ad_test_output, trtllm_test_output, rtol=1e-6, atol=1e-6) + + _print_diff_if(lambda diff: diff.max() > 1e-1, diff, ad_test_output, ref_output) + torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1) + + +FP8_TEST_DTYPES = [ + (torch.float8_e4m3fn, torch.bfloat16, torch.float8_e4m3fn), + (torch.float8_e4m3fn, torch.float16, torch.float8_e4m3fn), +] + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("top_k", TOP_K_VALUES) +@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) +@pytest.mark.parametrize("itype, otype, wtype", FP8_TEST_DTYPES) +@pytest.mark.parametrize("activation_func", ["silu", "relu2"]) +@pytest.mark.skipif( + not fp8_compatible() or not trtllm_ops_available(), + reason="Requires fp8 and trtllm support", +) +def test_trtllm_fused_fp8moe( + batch_size, + hidden_size, + num_experts, + top_k, + intermediate_size, + itype, + otype, + wtype, + activation_func, +): + # Skip invalid configurations + if top_k > num_experts: + pytest.skip(f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})") + + assert itype == torch.float8_e4m3fn and wtype == torch.float8_e4m3fn, ( + "FP8 test only supports float8_e4m3fn" + ) + assert otype == torch.bfloat16 or otype == torch.float16, ( + "FP8 test only supports bfloat16 or float16 output type" + ) + + torch.manual_seed(42) + if activation_func in ["swiglu", "silu"]: + X_GEN_SCALE = 1.0 + else: + X_GEN_SCALE = 0.5 + + def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales): + # input_shape = (batch_size, hidden_size) + w31_shape = (num_experts, 2 * intermediate_size, hidden_size) + w2_shape = (num_experts, hidden_size, intermediate_size) + + w31_dequantized = gen_tensor(w31_weight.shape, otype) + w2_dequantized = gen_tensor(w2_weight.shape, otype) + for expert_id in range(num_experts): + w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1)) + w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09)) + w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31) + w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2) + w31_weight.data[expert_id].copy_(w31_quant) + w2_weight.data[expert_id].copy_(w2_quant) + w31_scales.data[expert_id].copy_(s31) + w2_scales.data[expert_id].copy_(s2) + w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31)) + w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2)) + return w31_dequantized, w2_dequantized + + x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data( + otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE + ) + + w31_dequantized, w2_dequantized = dequantize_weights( + w31_weight, w2_weight, w31_scales, w2_scales + ) + + routing_weights, selected_experts = compute_routing(router_logits, top_k) + ref_output = compute_with_experts( + num_experts, + x, + w31_dequantized, + w2_dequantized, + selected_experts, + routing_weights, + 
activation_func=activation_func, + ) + + # For fp8, the hidden_state expects quantized. + w3_scales, w1_scales = torch.chunk(w31_scales, 2, dim=-1) + + x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x) + hidden_states_scale = hidden_states_scale[0].detach().clone().cuda() + + w3_input_scale = torch.tensor([1.0]).cuda() + w2_input_scale = torch.tensor([1.0]).cuda() + quant_scales = [ + torch.squeeze(w1_scales * hidden_states_scale).float(), # gemm1 dequant scale + w3_input_scale[0], # gemm2 activation quant scale + torch.squeeze(w2_scales * w2_input_scale[0]).float(), # gemm2 dequant scale + hidden_states_scale, # gemm1 input dequant scale + ] + + torch.cuda.synchronize() + print("before fused_moe.cutlass_fused_moe") + + # (num_experts, 2 * intermediate_size, hidden_size) => (num_experts, intermediate_size, hidden_size) + w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1) + + activation_type = _activation_type_from_str(activation_func) + + ad_test_output = torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused( + x, # Note! unquantized input is expected + selected_experts.to(torch.int), + routing_weights, + w1_weight=w1_weight.contiguous(), + w2_weight=w2_weight.contiguous(), + w3_weight=w3_weight.contiguous(), + w1_input_scale=hidden_states_scale.unsqueeze(0), + w2_input_scale=w2_input_scale, + w3_input_scale=w3_input_scale, + w1_weight_scale=w1_scales, + w2_weight_scale=w2_scales, + w3_weight_scale=w3_scales, + mlp_style="mlp" if activation_func == "relu2" else "gated_mlp", + act_fn=activation_func, + ) + + _ = torch.ops.trtllm.fused_moe( + x_quant, # Note! quantized input is expected + selected_experts.to(torch.int), + routing_weights, + fc1_expert_weights=w1_weight.contiguous() if activation_func == "relu2" else w31_weight, + fc1_expert_biases=None, + fc2_expert_weights=w2_weight, + fc2_expert_biases=None, + output_dtype=otype, + quant_scales=quant_scales, + activation_type=activation_type, + )[0].view(x_quant.shape) + torch.cuda.synchronize() + + diff = (ref_output - ad_test_output).abs() + print(f"max diff: {diff.max()}") + # assert trtllm_test_output is not None + # torch.testing.assert_close(ad_test_output, trtllm_test_output, rtol=1e-6, atol=1e-6) + + _print_diff_if(lambda diff: diff.max() > 1e-1, diff, ad_test_output, ref_output) + torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1) diff --git a/tests/unittest/llmapi/apps/test_chat_utils.py b/tests/unittest/llmapi/apps/test_chat_utils.py new file mode 100644 index 00000000000..f055c4fabb1 --- /dev/null +++ b/tests/unittest/llmapi/apps/test_chat_utils.py @@ -0,0 +1,179 @@ +from unittest.mock import MagicMock + +import pytest + +from tensorrt_llm.serve.chat_utils import parse_chat_message_content + + +@pytest.fixture +def mock_mm_data_tracker(): + """Create a mock MultimodalDataTracker for testing.""" + return MagicMock() + + +class TestParseAssistantMessages: + """Test suite for assistant role messages.""" + + @pytest.mark.parametrize("content", [None, "Hello, how can I help you?"]) + def test_assistant_message_no_tool_calls( + self, + mock_mm_data_tracker, + content, + ): + """Test parsing an assistant message with simple string content.""" + message = {"role": "assistant", "content": content} + + result = parse_chat_message_content(message, mock_mm_data_tracker) + + assert result["role"] == "assistant" + assert result["content"] == (content or "") + assert result["media"] == [] + assert "tool_calls" not in result + + @pytest.mark.parametrize( + "arguments", + [ + # JSON string. 
+ '{"location": "San Francisco", "unit": "celsius"}', + # Python dict. + {"location": "San Francisco", "unit": "celsius"}, + ], + ) + def test_assistant_message_with_tool_calls_string_arguments( + self, mock_mm_data_tracker, arguments + ): + """Test parsing an assistant message with tool calls where arguments are JSON strings.""" + message = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": arguments, + }, + } + ], + } + + result = parse_chat_message_content(message, mock_mm_data_tracker) + + assert result == { + "role": "assistant", + "content": "", + "media": [], + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": {"location": "San Francisco", "unit": "celsius"}, + }, + } + ], + } + + def test_assistant_message_with_empty_tool_arguments(self, mock_mm_data_tracker): + """Test parsing an assistant message with tool calls that have no arguments.""" + message = { + "role": "assistant", + "content": "Foobar", + "tool_calls": [ + { + "id": "call_789", + "type": "function", + "function": {"name": "get_current_time", "arguments": None}, + } + ], + } + + result = parse_chat_message_content(message, mock_mm_data_tracker) + + expected = { + "role": "assistant", + "content": "Foobar", + "media": [], + "tool_calls": [ + { + "id": "call_789", + "type": "function", + "function": {"name": "get_current_time", "arguments": {}}, + } + ], + } + assert result == expected + + def test_assistant_message_with_multiple_tool_calls(self, mock_mm_data_tracker): + """Test parsing an assistant message with multiple tool calls.""" + message = { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"location": "New York"}'}, + }, + { + "id": "call_2", + "type": "function", + "function": {"name": "get_time", "arguments": {"timezone": "EST"}}, + }, + {"id": "call_3", "type": "function", "function": {"name": "no_args_function"}}, + ], + } + + result = parse_chat_message_content(message, mock_mm_data_tracker) + + expected = { + "role": "assistant", + "content": "", + "media": [], + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "get_weather", "arguments": {"location": "New York"}}, + }, + { + "id": "call_2", + "type": "function", + "function": {"name": "get_time", "arguments": {"timezone": "EST"}}, + }, + { + "id": "call_3", + "type": "function", + "function": {"name": "no_args_function", "arguments": {}}, + }, + ], + } + assert result == expected + + +class TestParseToolMessages: + """Test suite for tool role messages.""" + + @pytest.mark.parametrize("content", ["The weather in San Francisco is 72°F and sunny.", None]) + def test_tool_message_with_tool_call_id(self, mock_mm_data_tracker, content): + """Test parsing a tool message with tool_call_id.""" + message = {"role": "tool", "content": (content or ""), "tool_call_id": "call_123"} + + result = parse_chat_message_content(message, mock_mm_data_tracker) + + expected = {**message, "media": []} + assert result == expected + + def test_tool_message_without_tool_call_id(self, mock_mm_data_tracker): + """Test parsing a tool message without tool_call_id.""" + message = { + "role": "tool", + "content": "Database query completed successfully.", + } + + result = parse_chat_message_content(message, mock_mm_data_tracker) + + expected = {**message, "media": []} + 
assert result == expected diff --git a/tests/unittest/llmapi/apps/test_tool_parsers.py b/tests/unittest/llmapi/apps/test_tool_parsers.py index 511f6a47fb0..a968a8058fd 100644 --- a/tests/unittest/llmapi/apps/test_tool_parsers.py +++ b/tests/unittest/llmapi/apps/test_tool_parsers.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import json +from typing import NamedTuple import pytest @@ -21,6 +23,8 @@ FunctionDefinition) from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser from tensorrt_llm.serve.tool_parser.core_types import StructureInfo +from tensorrt_llm.serve.tool_parser.qwen3_coder_parser import \ + Qwen3CoderToolParser from tensorrt_llm.serve.tool_parser.qwen3_tool_parser import Qwen3ToolParser @@ -66,7 +70,7 @@ def sample_tools(): } }, "required": ["query"] - })) + })), ] @@ -115,7 +119,7 @@ def test_get_tool_indices(self, sample_tools): parser = ConcreteToolParser() indices = parser._get_tool_indices(sample_tools) - assert len(indices) == 2 + assert len(indices) == len(sample_tools) assert indices["get_weather"] == 0 assert indices["search_web"] == 1 @@ -320,113 +324,215 @@ def test_structure_info(self): # ============================================================================ -class TestQwen3ToolParser: - """Test suite for Qwen3ToolParser class.""" +class ToolParserTestCases(NamedTuple): + has_tool_call_true: str + detect_and_parse_single_tool: tuple[ + # Input text. + str, + # Expected `normal_text`. + str, + # Expected `name`. + str, + # Expected `parameters`. + dict, + ] + detect_and_parse_multiple_tools: tuple[ + # Input text. + str, + # Expected names. + tuple[str], + ] + detect_and_parse_malformed_tool: str + detect_and_parse_with_parameters_key: tuple[ + # Input text. + str, + # Expected `name`. + str, + # Expected `parameters`. + dict, + ] + parse_streaming_increment_partial_bot_token: str + undefined_tool: str - def test_initialization(self): - """Test that Qwen3ToolParser initializes correctly.""" - parser = Qwen3ToolParser() - assert parser.bot_token == "\n" - assert parser.eot_token == "\n" - assert parser.tool_call_separator == "\n" - assert parser._normal_text_buffer == "" +class BaseToolParserTestClass: + """Base class from which tests for actual implementations can be extended. - def test_has_tool_call_true(self): - """Test has_tool_call returns True when tool call is present.""" - parser = Qwen3ToolParser() - text = 'Some text \n{"name":"get_weather"}\n' + NOTE: the name deliberately ends with `Class` so that `pytest` does not pick it up for execution + automatically. + """ - assert parser.has_tool_call(text) is True + @abc.abstractmethod + def make_parser(self): + ... + + @property + def make_tool_parser_test_cases(self) -> ToolParserTestCases: + ... 
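+    # Sketch of the expected hook-up (mirrors the concrete subclasses further down):
+    #
+    #     class TestQwen3ToolParser(BaseToolParserTestClass):
+    #         def make_parser(self):
+    #             return Qwen3ToolParser()
+    #
+    #         def make_tool_parser_test_cases(self):
+    #             return ToolParserTestCases(...)
+    #
+    # The fixtures below feed both hooks into the shared tests.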
+ + @pytest.fixture + def parser(self): + return self.make_parser() + + @pytest.fixture(scope="class") + def tool_parser_test_cases(self) -> ToolParserTestCases: + return self.make_tool_parser_test_cases() - def test_has_tool_call_false(self): + def test_has_tool_call_false(self, parser): """Test has_tool_call returns False when no tool call present.""" - parser = Qwen3ToolParser() text = "Just some regular text without tool calls" assert parser.has_tool_call(text) is False - def test_detect_and_parse_no_tool_call(self, sample_tools): + def test_has_tool_call_true(self, parser, tool_parser_test_cases): + """Test has_tool_call returns True when tool call is present.""" + text = tool_parser_test_cases.has_tool_call_true + + assert parser.has_tool_call(text) is True + + def test_detect_and_parse_no_tool_call(self, sample_tools, parser): """Test detect_and_parse with text containing no tool calls.""" - parser = Qwen3ToolParser() text = "This is just a regular response." result = parser.detect_and_parse(text, sample_tools) - assert result.normal_text == "This is just a regular response." + assert result.normal_text == text assert len(result.calls) == 0 - def test_detect_and_parse_single_tool(self, sample_tools): + def test_detect_and_parse_single_tool(self, sample_tools, parser, + tool_parser_test_cases): """Test detect_and_parse with a single tool call.""" - parser = Qwen3ToolParser() - text = 'Normal text\n\n{"name":"get_weather","arguments":{"location":"NYC"}}\n' + text, normal_text, name, parameters = tool_parser_test_cases.detect_and_parse_single_tool result = parser.detect_and_parse(text, sample_tools) - assert result.normal_text == "Normal text" + assert result.normal_text == normal_text assert len(result.calls) == 1 - assert result.calls[0].name == "get_weather" - assert json.loads(result.calls[0].parameters) == {"location": "NYC"} + assert result.calls[0].name == name + assert json.loads(result.calls[0].parameters) == parameters - def test_detect_and_parse_multiple_tools(self, sample_tools): + def test_detect_and_parse_multiple_tools(self, sample_tools, parser, + tool_parser_test_cases): """Test detect_and_parse with multiple tool calls.""" - parser = Qwen3ToolParser() - text = ( - '\n{"name":"get_weather","arguments":{"location":"LA"}}\n\n' - '\n{"name":"search_web","arguments":{"query":"AI"}}\n' - ) + text, call_names = tool_parser_test_cases.detect_and_parse_multiple_tools result = parser.detect_and_parse(text, sample_tools) - assert len(result.calls) == 2 - assert result.calls[0].name == "get_weather" - assert result.calls[1].name == "search_web" + assert tuple(call.name for call in result.calls) == call_names - def test_detect_and_parse_malformed_json(self, sample_tools): - """Test detect_and_parse handles malformed JSON gracefully.""" - parser = Qwen3ToolParser() - text = '\n{"name":"get_weather","arguments":MALFORMED}\n' + def test_detect_and_parse_malformed_tool(self, sample_tools, parser, + tool_parser_test_cases): + """Test detect_and_parse handles malformed tool call output from the model gracefully.""" + text = tool_parser_test_cases.detect_and_parse_malformed_tool result = parser.detect_and_parse(text, sample_tools) - # Should return empty calls due to JSON parsing error assert len(result.calls) == 0 - def test_detect_and_parse_with_parameters_key(self, sample_tools): + def test_detect_and_parse_with_parameters_key(self, sample_tools, parser, + tool_parser_test_cases): """Test detect_and_parse handles 'parameters' key.""" - parser = Qwen3ToolParser() - text = 
'\n{"name":"search_web","parameters":{"query":"test"}}\n' + text, name, parameters = tool_parser_test_cases.detect_and_parse_with_parameters_key result = parser.detect_and_parse(text, sample_tools) assert len(result.calls) == 1 - assert result.calls[0].name == "search_web" - assert json.loads(result.calls[0].parameters) == {"query": "test"} + assert result.calls[0].name == name + assert json.loads(result.calls[0].parameters) == parameters - def test_parse_streaming_increment_normal_text(self, sample_tools): + def test_parse_streaming_increment_normal_text(self, sample_tools, parser): """Test streaming parser handles normal text without tool calls.""" - parser = Qwen3ToolParser() + text = "Hello, how can I help?" - result = parser.parse_streaming_increment("Hello, how can I help?", - sample_tools) + result = parser.parse_streaming_increment(text, sample_tools) - assert result.normal_text == "Hello, how can I help?" + assert result.normal_text == text assert len(result.calls) == 0 - def test_parse_streaming_increment_partial_bot_token(self, sample_tools): + def test_parse_streaming_increment_partial_bot_token( + self, sample_tools, parser, tool_parser_test_cases): """Test streaming parser buffers partial bot token.""" - parser = Qwen3ToolParser() + text = tool_parser_test_cases.parse_streaming_increment_partial_bot_token - # Send partial bot token - result = parser.parse_streaming_increment("\n{"name":"get_weather"}\n', + detect_and_parse_single_tool=( + # Input text. + ('Normal text\n' + '\n' + '{"name":"get_weather","arguments":{"location":"NYC"}}\n' + ''), + # Expected `normal_text`. + "Normal text", + # Expected `name`. + "get_weather", + # Expected `parameters`. + { + "location": "NYC" + }, + ), + detect_and_parse_multiple_tools=( + # Input text. + ('\n{"name":"get_weather","arguments":{"location":"LA"}}\n\n' + '\n{"name":"search_web","arguments":{"query":"AI"}}\n' + ), + # Expected names. + ("get_weather", "search_web"), + ), + detect_and_parse_malformed_tool= + ('\n{"name":"get_weather","arguments":MALFORMED}\n' + ), + detect_and_parse_with_parameters_key=( + # Input text. + ('\n{"name":"search_web","parameters":{"query":"test"}}\n' + ), + # Expected `name`. + "search_web", + # Expected `parameters`. + { + "query": "test" + }, + ), + parse_streaming_increment_partial_bot_token="\n{"name":"undefined_func","arguments":{}}\n', + ) + + def test_initialization(self, parser): + """Test that Qwen3ToolParser initializes correctly.""" + assert parser.bot_token == "\n" + assert parser.eot_token == "\n" + assert parser.tool_call_separator == "\n" + assert parser._normal_text_buffer == "" + + # NOTE: this is not put in the base class. Even though it could be made generic, the added logic + # to do so loses the clarity of this more direct approach. 
+ def test_parse_streaming_increment_complete_tool_call( + self, sample_tools, parser): """Test streaming parser with complete tool call in chunks.""" - parser = Qwen3ToolParser() # Send bot token parser.parse_streaming_increment("\n", sample_tools) @@ -448,9 +554,9 @@ def test_parse_streaming_increment_complete_tool_call(self, sample_tools): assert len(result.calls) == 1 assert json.loads(result.calls[0].parameters) == {"location": "SF"} - def test_parse_streaming_increment_end_token_handling(self, sample_tools): + def test_parse_streaming_increment_end_token_handling( + self, sample_tools, parser): """Test streaming parser handles end token correctly.""" - parser = Qwen3ToolParser() # Send complete tool call parser.parse_streaming_increment( @@ -462,9 +568,8 @@ def test_parse_streaming_increment_end_token_handling(self, sample_tools): assert parser._normal_text_buffer == "" def test_parse_streaming_increment_multiple_tools_streaming( - self, sample_tools): + self, sample_tools, parser): """Test streaming parser handles multiple tool calls.""" - parser = Qwen3ToolParser() # First tool parser.parse_streaming_increment('\n', sample_tools) @@ -506,9 +611,8 @@ def test_structure_info_different_names(self): assert "search_web" in info2.begin assert info1.end == info2.end == "}\n" - def test_qwen3_format_compliance(self, sample_tools): + def test_qwen3_format_compliance(self, sample_tools, parser): """Test that Qwen3ToolParser follows the documented format structure.""" - parser = Qwen3ToolParser() # Test the exact format from the docstring text = '\n{"name":"get_weather", "arguments":{"location":"Tokyo"}}\n' @@ -519,16 +623,239 @@ def test_qwen3_format_compliance(self, sample_tools): assert result.calls[0].name == "get_weather" assert json.loads(result.calls[0].parameters) == {"location": "Tokyo"} - def test_undefined_tool_in_qwen3_format(self, sample_tools): - """Test Qwen3ToolParser handles undefined tool gracefully.""" - parser = Qwen3ToolParser() - text = '\n{"name":"undefined_func","arguments":{}}\n' - result = parser.detect_and_parse(text, sample_tools) +class TestQwen3CoderToolParser(BaseToolParserTestClass): + """Test suite for Qwen3CoderToolParser class.""" + + def make_parser(self): + return Qwen3CoderToolParser() + + def make_tool_parser_test_cases(self): + return ToolParserTestCases( + has_tool_call_true=("Some text \n" + "\n" + "NYC\n" + "\n" + ""), + detect_and_parse_single_tool=( + # Input text. + ("Normal text\n" + "\n" + "\n" + "NYC\n" + "\n" + ""), + # Expected `normal_text`. + "Normal text\n", + # Expected `name`. + "get_weather", + # Expected `parameters`. + { + "location": "NYC" + }, + ), + detect_and_parse_multiple_tools=( + # Input text. + ("\n" + "\n" + "LA\n" + "\n" + "\n" + "\n" + "\n" + "AI\n" + "\n" + ""), + # Expected names. + ("get_weather", "search_web"), + ), + detect_and_parse_malformed_tool=( + # Typo. + # NOTE: the regexes + logic in `Qwen3CoderToolParser` seems deliberately forgiving. + # For example, forgetting the closing `` is fine, as is the closing + # ``. However, the values returned in the function call information + # might be dubious as a result. + "\n" + "\n" + "San Francisco, CA\n" + "\n" + ""), + detect_and_parse_with_parameters_key=( + # Input text (Qwen3Coder uses "parameter", not "parameters"). + ("\n" + "\n" + "test\n" + "\n" + ""), + # Expected `name`. + "search_web", + # Expected `parameters`. 
+ { + "query": "test" + }, + ), + parse_streaming_increment_partial_bot_token="", + undefined_tool=("\n" + "\n" + "value\n" + "\n" + ""), + ) - # Should not return any calls for undefined function + def test_parse_streaming_increment_complete_tool_call( + self, sample_tools, parser): + """Test streaming parser with complete tool call in chunks.""" + + # Send tool call start token + result = parser.parse_streaming_increment("\n", sample_tools) assert len(result.calls) == 0 + # Send function declaration + result = parser.parse_streaming_increment("\n", + sample_tools) + + # Should send tool name with empty parameters + assert len(result.calls) == 1 + assert result.calls[0].name == "get_weather" + assert result.calls[0].parameters == "" + + # Send parameter block + result = parser.parse_streaming_increment( + 'SF\n\n', + sample_tools) + + # Should stream parameters + assert len(result.calls) >= 1 + # Check that parameters were sent (could be in multiple chunks) + all_params = "".join(call.parameters for call in result.calls + if call.parameters) + assert "location" in all_params + assert "SF" in all_params + + def test_parse_streaming_increment_end_token_handling( + self, sample_tools, parser): + """Test streaming parser handles end token correctly.""" + + # Send complete tool call + parser.parse_streaming_increment( + "\n" + "\n" + "NYC\n" + "\n" + "", sample_tools) + + # Check buffer state - should be cleared after complete tool call + assert parser._buf == "" + assert parser._in_tool_call is False + + def test_parse_streaming_increment_multiple_tools_streaming( + self, sample_tools, parser): + """Test streaming parser handles multiple tool calls.""" + + # First tool. + parser.parse_streaming_increment("\n", sample_tools) + parser.parse_streaming_increment("\n", + sample_tools) + parser.parse_streaming_increment( + "NYC\n\n\n", + sample_tools) + + # Second tool. + parser.parse_streaming_increment("\n", sample_tools) + result = parser.parse_streaming_increment("\n", + sample_tools) + + # Should have started second tool. 
+ assert result.calls[0].name == "search_web" + assert result.calls[0].parameters == "" + assert result.calls[0].tool_index == 1 + + def test_parse_streaming_increment_multiple_parameters( + self, sample_tools, parser): + """Test parser handles multiple parameters in a single function call.""" + + tool_def = ChatCompletionToolsParam( + type="function", + function=FunctionDefinition( + name="multi_param_func", + description="Function with multiple parameters", + parameters={ + "type": "object", + "properties": { + "param1": { + "type": "string" + }, + "param2": { + "type": "string" + }, + "param3": { + "type": "integer" + } + }, + "required": ["param1", "param2", "param3"] + })) + + text = ("\n" + "\n" + "value1\n" + "value2\n" + "42\n" + "\n" + "") + + result = parser.detect_and_parse(text, [tool_def]) + + assert len(result.calls) == 1 + assert result.calls[0].name == "multi_param_func" + assert json.loads(result.calls[0].parameters) == { + "param1": "value1", + "param2": "value2", + "param3": 42 + } + + def test_qwen3_coder_format_compliance( + self, + parser, + ): + """Test that Qwen3CoderToolParser follows the documented format structure.""" + + # Test the exact format from the docstring + text = ("\n" + "\n" + "\n" + "pwd && ls\n" + "\n" + "\n" + "") + + tool_def = ChatCompletionToolsParam( + type="function", + function=FunctionDefinition( + name="execute_bash", + description="Execute a bash command.", + parameters={ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The command to execute.", + }, + "unit": { + "type": "string", + } + }, + "required": ["command"], + }, + )) + + result = parser.detect_and_parse(text, [tool_def]) + + assert len(result.calls) == 1 + assert result.calls[0].name == "execute_bash" + assert json.loads(result.calls[0].parameters) == { + "command": "pwd && ls" + } + # ============================================================================ # Integration Tests