@@ -59,6 +59,30 @@ __forceinline__ __device__ float tanh_opt(float x)
#endif
}

template <typename T>
struct Relu2
{
static bool const kIsHeavy = false;

CUTLASS_HOST_DEVICE
T operator()(T threshold, T value) const
{
ReLu<T> relu_op;
multiplies<T> mul;
T val = relu_op(threshold, value);
return mul(val, val);
}

CUTLASS_HOST_DEVICE
T operator()(T value) const
{
ReLu<T> relu_op;
multiplies<T> mul;
T val = relu_op(value);
return mul(val, val);
}
};

} // namespace thread
} // namespace epilogue
} // namespace cutlass
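
For context, the Relu2 functor added above computes squared ReLU, i.e. relu(x) * relu(x), so it can serve both as an epilogue op and (further down in this diff) as a MoE activation. A minimal PyTorch sketch of the same math, illustrative only — relu2_reference is a hypothetical name, not part of this diff:

import torch

def relu2_reference(x: torch.Tensor) -> torch.Tensor:
    # Squared ReLU: max(x, 0) ** 2, matching Relu2::operator()(value) above.
    return torch.relu(x) ** 2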
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
@@ -28,6 +28,7 @@ enum class ActivationType
Swiglu,
Geglu,
SwigluBias,
Relu2,
Identity,
InvalidType
};
@@ -954,6 +954,7 @@ void MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::moeGemmBiasAct(
case ActivationType::Identity: runGemm<cutlass_extensions::EpilogueOpDefault>(inputs, hopper_inputs); break;
case ActivationType::Swiglu: runGemm<cutlass_extensions::EpilogueOpDefaultSilu>(inputs, hopper_inputs); break;
case ActivationType::Geglu: runGemm<cutlass_extensions::EpilogueOpDefaultFtGelu>(inputs, hopper_inputs); break;
case ActivationType::Relu2: TLLM_THROW("Relu2 is not supported."); break;
case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break;
default: TLLM_THROW("Invalid activation type."); break;
}
@@ -2307,6 +2307,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
decltype(block_scaling_type)::value>, // Geglu
&doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
decltype(block_scaling_type)::value>, // SwigluBias
&doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
decltype(block_scaling_type)::value>, // Relu2
&doActivationKernel<T, GemmOutputType, ScaleBiasType,
IdentityAdaptor<cutlass::epilogue::thread::Identity>,
decltype(block_scaling_type)::value> // Identity
37 changes: 26 additions & 11 deletions cpp/tensorrt_llm/thop/moeOp.cpp
@@ -259,8 +259,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
torch::optional<torch::Tensor> const& out_tensor)
torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
{
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
@@ -328,6 +328,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
TORCH_CHECK(fc1_expert_weights.sizes()[0] == fc2_expert_weights.sizes()[0],
"fc1_expert_weights and fc2_expert_weights must have the same number of experts.");

ActivationType base_activation_type = activation_type.has_value()
? static_cast<ActivationType>(activation_type.value())
: ActivationType::Swiglu;
if (mUseINT8WoqPerChannel)
{
// Note: The weight shape for INT8 weight only quantization is different, e.g., fc2_expert_weights:
@@ -337,8 +340,16 @@ class FusedMoeRunner : public torch::CustomClassHolder
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be fc2_expert_weights inter size.");
if (isGatedActivation(base_activation_type))
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2,
"fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size.");
}
else
{
TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier,
"fc1_expert_weights inter size must be equal to fc2_expert_weights inter size.");
}
}

int experts_per_token = token_selected_experts.sizes()[1];
@@ -375,7 +386,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
int const num_experts_on_rank = fc2_expert_weights.sizes()[0];
auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank);
ActivationType base_activation_type = ActivationType::Swiglu;

if (swiglu_alpha.has_value())
{
CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -474,8 +485,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
torch::optional<torch::Tensor> const& swiglu_limit, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids,
torch::optional<int64_t> const& unpadded_hidden_size, torch::optional<int64_t> const& num_valid_tokens,
torch::optional<torch::Tensor> const& out_tensor)
torch::optional<int64_t> const& activation_type, torch::optional<int64_t> const& unpadded_hidden_size,
torch::optional<int64_t> const& num_valid_tokens, torch::optional<torch::Tensor> const& out_tensor)
{
std::lock_guard<std::mutex> lock(mMutex);

@@ -541,7 +552,9 @@ class FusedMoeRunner : public torch::CustomClassHolder
auto const num_experts_total = static_cast<int>(num_experts_on_rank * ep_size);
auto parallelism_config
= kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank, cluster_size, cluster_rank);
ActivationType base_activation_type = ActivationType::Swiglu;
ActivationType base_activation_type = activation_type.has_value()
? static_cast<ActivationType>(activation_type.value())
: ActivationType::Swiglu;
if (swiglu_alpha.has_value())
{
CHECK_INPUT(swiglu_alpha.value(), at::ScalarType::Float);
@@ -652,7 +665,8 @@ class FusedMoeRunner : public torch::CustomClassHolder
torch::optional<torch::Tensor> const& fc2_expert_biases, int64_t const top_k, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, int64_t const gemm_idx,
int64_t const profile_id, bool const do_preparation, int64_t const unpadded_hidden_size)
int64_t const profile_id, bool const do_preparation, int64_t const activation_type_int,
int64_t const unpadded_hidden_size)
{
std::lock_guard<std::mutex> lock(mMutex);

@@ -661,6 +675,7 @@ class FusedMoeRunner : public torch::CustomClassHolder
{
return;
}
ActivationType activation_type = static_cast<ActivationType>(activation_type_int);

int64_t const num_rows = input.sizes()[0];
int64_t hidden_size = fc2_expert_weights.sizes()[1];
@@ -715,14 +730,14 @@ class FusedMoeRunner : public torch::CustomClassHolder
tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, unpadded_hidden_size > 0 ? unpadded_hidden_size : hidden_size, inter_size, group_size,
ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
activation_type, USE_BIAS, USE_LORA, min_latency_mode,
/*need_weights*/ false, parallelism_config, enable_alltoall);
#else
mProfiler->init(*mKernelRunner.get(), mProfiler->mGemmToProfile,
tensorrt_llm::runtime::TorchUtils::dataType(activation_dtype),
tensorrt_llm::runtime::TorchUtils::dataType(mWeightDtype),
tensorrt_llm::runtime::TorchUtils::dataType(mOutputDtype), num_experts, static_cast<int>(top_k),
hidden_size, inter_size, group_size, ActivationType::Swiglu, USE_BIAS, USE_LORA, min_latency_mode,
hidden_size, inter_size, group_size, activation_type, USE_BIAS, USE_LORA, min_latency_mode,
/*need_weights*/ false, parallelism_config);
#endif
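
The moeOp.cpp changes above thread an optional activation_type argument through runMoe and the profiler (falling back to Swiglu when it is not provided) and split the fc1/fc2 shape check: gated activations expect fc1's inter dimension to be twice fc2's, while non-gated ones such as Relu2 expect them to match. A small Python sketch of that invariant, using hypothetical expert shapes that are not taken from the diff:

def check_expert_shapes(fc1_shape, fc2_shape, inner_dim_multiplier, gated):
    # Mirrors the TORCH_CHECKs above: fc1's dim 1 carries the (possibly doubled) inter size.
    inter_size = fc2_shape[2] * inner_dim_multiplier
    expected = 2 * inter_size if gated else inter_size
    assert fc1_shape[1] == expected, "fc1_expert_weights inter size mismatch"

check_expert_shapes((8, 4096, 1024), (8, 1024, 2048), 1, gated=True)   # Swiglu-style expert
check_expert_shapes((8, 2048, 1024), (8, 1024, 2048), 1, gated=False)  # Relu2-style expert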

3 changes: 3 additions & 0 deletions examples/auto_deploy/.gitignore
@@ -4,3 +4,6 @@ benchmark_results.json
*.png
# ignore config files that users might put here for debugging
*.yaml
!nano_v3.yaml
!nano_v3_accuracy.yaml
!nano_v3_bench.yaml
23 changes: 23 additions & 0 deletions examples/auto_deploy/nano_v3.yaml
@@ -0,0 +1,23 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_batch_size: 384
max_seq_len: 65536 # tunable
enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
transforms:
detect_sharding:
sharding_source: ['factory', 'heuristic']
sharding_dims: ['ep', 'bmm']
# tunable mamba cache dtype
# --> use float32 for accuracy and default (null) for speed
insert_cached_ssm_attention:
cache_config:
# mamba_dtype: float32
mamba_dtype: null
23 changes: 23 additions & 0 deletions examples/auto_deploy/nano_v3_accuracy.yaml
@@ -0,0 +1,23 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_batch_size: 128
max_seq_len: 204800
enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
transforms:
detect_sharding:
sharding_source: ['factory', 'heuristic']
sharding_dims: ['ep', 'bmm']
# tunable mamba cache dtype
# --> use float32 for accuracy and default (null) for speed
insert_cached_ssm_attention:
cache_config:
mamba_dtype: float32
# mamba_dtype: null
23 changes: 23 additions & 0 deletions examples/auto_deploy/nano_v3_bench.yaml
@@ -0,0 +1,23 @@
runtime: trtllm
compile_backend: torch-cudagraph
max_batch_size: 384 # tunable
max_seq_len: 65536 # tunable
enable_chunked_prefill: true
attn_backend: flashinfer
model_factory: AutoModelForCausalLM
skip_loading_weights: false
free_mem_ratio: 0.9
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
kv_cache_config:
# disable kv_cache reuse since not supported for hybrid/ssm models
enable_block_reuse: false
transforms:
detect_sharding:
sharding_source: ['factory', 'heuristic']
sharding_dims: ['ep', 'bmm']
# tunable mamba cache dtype
# --> use float32 for accuracy and default (null) for speed
insert_cached_ssm_attention:
cache_config:
# mamba_dtype: float32
mamba_dtype: null
@@ -175,7 +175,7 @@ def forward(self, *args, **kwargs) -> Any:

# retrieve output from buffer, cut to batch size, and unflatten
bs = args_batched[0].shape[0]
out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
out_flat = [o_b[:bs] for o_b in self._out_buffer_flat]
return self._out_spec.unflatten(out_flat)
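
Dropping the detach().clone() above means the values returned from the captured-graph path are now views into the persistent output buffer rather than copies, so they stay valid only until that buffer is rewritten (e.g. by the next replay). A tiny sketch of that aliasing behaviour with hypothetical shapes, not taken from the diff:

import torch

buffer = torch.zeros(8, 4)      # stands in for a persistent graph output buffer
out = buffer[:2]                # what the code above now returns: a view, no copy
buffer.fill_(1.0)               # the next replay overwrites the same storage
assert out.max().item() == 1.0  # the previously returned "output" changed too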


2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -165,6 +165,8 @@ transforms:
############################################################################################
# COMPILE MODEL
############################################################################################
fuse_causal_conv_activation:
stage: compile
compile_model:
stage: compile
run_per_gm: false
@@ -10,10 +10,10 @@
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator
from torch._ops import OpOverloadPacket
from torch.fx import Node
from torch.types import Number
Expand All @@ -24,11 +24,39 @@
Constant = Union[int, float, str, None]


@dataclass
class CacheConfig:
"""A dataclass to hold information how to configure the cache."""
class CacheConfig(BaseModel):
"""Cache configuration for attention-related dtypes."""

dtype: Optional[torch.dtype] = None
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="forbid",
)

dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")

@field_validator("dtype", "mamba_dtype", mode="before")
@classmethod
def _coerce_dtype(cls, value):
if value is None or isinstance(value, torch.dtype):
return value
if isinstance(value, str):
dtype = getattr(torch, value, None)
assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
return dtype
return value

def __or__(self, other: "CacheConfig") -> "CacheConfig":
"""Combine two CacheConfig objects field-wise using Python's `or` semantics.

For each field, selects the first non-None value between `self` and `other`.
"""
if not isinstance(other, CacheConfig):
raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
merged_kwargs = {}
for field_name in type(self).model_fields.keys():
merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
return CacheConfig(**merged_kwargs)
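
The CacheConfig rework above swaps the dataclass for a pydantic model that coerces string dtypes and can be merged field-wise with `|`. A short usage sketch, assuming CacheConfig is imported from the module being modified above (values are illustrative, not from the diff):

import torch

cfg_kv = CacheConfig(dtype="bfloat16")              # string is coerced to torch.bfloat16
cfg_mamba = CacheConfig(mamba_dtype=torch.float32)
merged = cfg_kv | cfg_mamba                         # per field, the first non-None value wins
assert merged.dtype is torch.bfloat16
assert merged.mamba_dtype is torch.float32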


class SequenceInfo:
@@ -88,6 +116,7 @@ def __init__(
page_size: int = 0,
max_num_tokens: Optional[int] = None,
vocab_size_padded: Optional[int] = None,
chunk_size: Optional[int] = None,
):
"""Initialize the SequenceInfo object.

@@ -114,7 +143,7 @@ def __init__(
self.max_batch_size = max_batch_size
self.page_size = page_size if page_size > 0 else max_seq_len
self.vocab_size_padded = vocab_size_padded

self.chunk_size = chunk_size
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
# (max_batch_size, max_seq_len) input in trtllm runtime.
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
@@ -165,7 +194,7 @@ def __init__(
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
"cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.long),
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
"_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
}
@@ -175,7 +204,7 @@ def __init__(
# NOTE: order of keys is relevant here!
self._uncached_arg_names = ("input_ids", "position_ids")
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
self._cached_constants = ("page_size",)
self._cached_constants = ("page_size", "chunk_size")
############################################################################################

# EXTRA TENSOR FIELDS ######################################################################
@@ -162,6 +162,7 @@ def prepare_flashinfer_metadata(
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
chunk_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for flashinfer attention.

@@ -213,7 +214,7 @@
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
):
seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
@@ -11,6 +11,8 @@
import triton
import triton.language as tl

from tensorrt_llm._utils import nvtx_range


@triton.jit
def _write_zeros_to_output(
@@ -304,6 +306,7 @@ def _default_kernel_config(M: int, E: int, N: int, K: int, top_k: int) -> dict:
}


@nvtx_range("triton_moe_pack_routed_tokens")
def _pack_routed_tokens(
topk_ids: torch.Tensor,
M: int,
Expand Down