
Commit 5eb77f9

lucaslie and Wanli-Jiang authored and committed
configurable kvcache/mamba cache
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent 72b2505 · commit 5eb77f9

File tree

4 files changed: +55, -12 lines changed


tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 33 additions & 5 deletions
@@ -10,10 +10,10 @@
 """
 
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union
 
 import torch
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from torch._ops import OpOverloadPacket
 from torch.fx import Node
 from torch.types import Number
@@ -24,11 +24,39 @@
 Constant = Union[int, float, str, None]
 
 
-@dataclass
-class CacheConfig:
-    """A dataclass to hold information how to configure the cache."""
+class CacheConfig(BaseModel):
+    """Cache configuration for attention-related dtypes."""
 
-    dtype: Optional[torch.dtype] = None
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        extra="forbid",
+    )
+
+    dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
+    mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
+
+    @field_validator("dtype", "mamba_dtype", mode="before")
+    @classmethod
+    def _coerce_dtype(cls, value):
+        if value is None or isinstance(value, torch.dtype):
+            return value
+        if isinstance(value, str):
+            dtype = getattr(torch, value, None)
+            assert isinstance(dtype, torch.dtype), f"Invalid {dtype=}"
+            return dtype
+        return value
+
+    def __or__(self, other: "CacheConfig") -> "CacheConfig":
+        """Combine two CacheConfig objects field-wise using Python's `or` semantics.
+
+        For each field, selects the first non-None value between `self` and `other`.
+        """
+        if not isinstance(other, CacheConfig):
+            raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
+        merged_kwargs = {}
+        for field_name in type(self).model_fields.keys():
+            merged_kwargs[field_name] = getattr(self, field_name) or getattr(other, field_name)
+        return CacheConfig(**merged_kwargs)
 
 
 class SequenceInfo:
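
Usage sketch for the new CacheConfig (a minimal example, assuming the module path shown in the file header above is importable; the dtype values are only illustrative):

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

# Strings are coerced to torch dtypes by the `_coerce_dtype` validator.
transform_cfg = CacheConfig(dtype="float8_e4m3fn")  # only the KV cache dtype is set
factory_cfg = CacheConfig(dtype=torch.float16, mamba_dtype=torch.float32)

# `|` merges field-wise; the left operand wins wherever it is not None.
merged = transform_cfg | factory_cfg
assert merged.dtype == torch.float8_e4m3fn
assert merged.mamba_dtype == torch.float32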

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py

Lines changed: 4 additions & 1 deletion
@@ -347,14 +347,17 @@ def get_cache_initializers(
             # Fallback: assume last dim is n_groups * state_size and choose a minimal positive size
             ssm_state_size = max(1, B_fake.shape[-1])
 
+        # extract ssm_state_dtype from cache_config or hs_fake
+        ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
+
         def _get_ssm_cache(si: SequenceInfo):
             return torch.empty(
                 si.max_batch_size,
                 num_heads,
                 head_dim,
                 ssm_state_size,
                 device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
+                dtype=ssm_state_dtype,
             )
 
         return {"ssm_state_cache": _get_ssm_cache}
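
A minimal sketch of the dtype fallback introduced above: the SSM state cache now keys off `mamba_dtype` and only follows the hidden-state dtype when it is unset (concrete dtypes below are illustrative):

import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig

hs_dtype = torch.bfloat16  # stand-in for hs_fake.dtype

# mamba_dtype unset: the SSM state cache follows the hidden-state dtype.
assert (CacheConfig().mamba_dtype or hs_dtype) == torch.bfloat16

# mamba_dtype set: it wins, independently of the KV-cache dtype.
cfg = CacheConfig(dtype=torch.float16, mamba_dtype=torch.float32)
assert (cfg.mamba_dtype or hs_dtype) == torch.float32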

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 6 additions & 4 deletions
@@ -125,6 +125,7 @@ def _triton_cached_ssm(
             dt_limit=(time_step_limit[0], time_step_limit[1]),
             return_final_states=False,
             return_varlen_states=True,
+            mamba_ssm_cache_dtype=ssm_state_cache.dtype,
         )
 
         y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
@@ -198,9 +199,7 @@ def _triton_cached_ssm_fake(
     )
 
 
-## Note: we reuse the existing metadata custom op and its registered fake from torch backend.
-
-
+# TODO: consider inheriting from TorchBackendSSM instead of redefining everything
 @AttentionRegistry.register("triton_ssm")
 class TritonBackendSSM(AttentionDescriptor):
     @classmethod
@@ -247,14 +246,17 @@ def get_cache_initializers(
         else:
             ssm_state_size = max(1, B_fake.shape[-1])
 
+        # extract ssm_state_dtype from cache_config or hs_fake
+        ssm_state_dtype = cache_config.mamba_dtype or hs_fake.dtype
+
         def _get_ssm_cache(si: SequenceInfo):
             return torch.empty(
                 si.max_batch_size,
                 num_heads,
                 head_dim,
                 ssm_state_size,
                 device=si.device,
-                dtype=cache_config.dtype or hs_fake.dtype,
+                dtype=ssm_state_dtype,
            )
 
         return {"ssm_state_cache": _get_ssm_cache}

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 12 additions & 2 deletions
@@ -8,7 +8,12 @@
 from pydantic import Field
 from torch.fx import GraphModule, Node
 
-from ...custom_ops.attention_interface import AttentionDescriptor, AttentionRegistry, Constant
+from ...custom_ops.attention_interface import (
+    AttentionDescriptor,
+    AttentionRegistry,
+    CacheConfig,
+    Constant,
+)
 from ...distributed.common import all_gather_object, get_world_size
 from ...distributed.common import is_initialized as is_distributed_initialized
 from ...models.factory import ModelFactory
@@ -66,6 +71,9 @@ class InsertCachedAttentionConfig(TransformConfig):
     """Configuration for the insert cached attention transform."""
 
     backend: Optional[str] = Field(default=None, description="The attention backend to use.")
+    cache_config: CacheConfig = Field(
+        default_factory=CacheConfig, description="The custom cache configuration to use."
+    )
 
 
 @TransformRegistry.register("insert_cached_attention")
@@ -137,7 +145,9 @@ def _apply(
         """Replace uncached source attention node with corresponding cached attn node."""
         attn_descriptor = self.attn_descriptor
 
-        cache_config = factory.get_cache_config()
+        # run field-wise or to combine the cache config from the transform and the factory
+        # the transform config takes precedence over the factory config
+        cache_config = self.config.cache_config | factory.get_cache_config()
 
         # Get all attention nodes and their info objects
         source_op = attn_descriptor.get_source_attention_op()
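
A hedged sketch of wiring up the new transform-level cache config. `InsertCachedAttentionConfig` and `CacheConfig` come from the files above; the exact construction site, the backend name, and any additional fields required by the `TransformConfig` base class are assumptions:

from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig
from tensorrt_llm._torch.auto_deploy.transform.library.kvcache import InsertCachedAttentionConfig

# NOTE: any required fields inherited from TransformConfig are omitted here (assumption).
cfg = InsertCachedAttentionConfig(
    backend="triton_ssm",  # illustrative; name taken from the registry decorator above
    cache_config=CacheConfig(dtype="float8_e4m3fn", mamba_dtype="float32"),
)

# Inside _apply, this transform-level cache_config is merged with the factory's via `|`,
# so any field left unset here falls back to the factory value.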
