examples/llm-api/llm_kv_cache_connector.py (89 changes: 84 additions & 5 deletions)
@@ -1,6 +1,84 @@
### :title KV Cache Connector
### :order 6
### :section Customization
'''
This script demonstrates the KV cache connector feature in TensorRT-LLM, which enables
custom persistence and reuse of KV cache blocks across different LLM instances.

**Scenario:**
The script implements a persistent KV cache connector that saves computed KV cache blocks
to disk and loads them back in subsequent runs, eliminating redundant computation for
recurring prompts.

**What is a KV Cache Connector?**

A KV cache connector is a customizable interface that allows you to:
1. **Save KV Cache:** Persist computed KV cache blocks to an external storage
(disk, database, distributed cache, etc.)
2. **Load KV Cache:** Retrieve previously computed cache blocks instead of recomputing them
3. **Share Cache Across Instances:** Reuse cache blocks across different LLM instances
   or sessions, unlike regular block reuse, which is limited to a single instance
   (a minimal sketch of the save/load idea follows this list)
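
As a rough illustration of the save/load idea above, here is a standalone sketch
(this is not the TensorRT-LLM connector API; `block_key`, `save_block`, and
`load_block` are made-up helpers, and `/tmp/kv_connector_cache` is an assumed location):

```python
# Minimal sketch: persist a KV cache block tensor to disk keyed by the tokens
# it covers, and load it back on a later run. Names and paths are illustrative.
import hashlib
from pathlib import Path

import torch

CACHE_DIR = Path("/tmp/kv_connector_cache")  # assumed storage location
CACHE_DIR.mkdir(parents=True, exist_ok=True)


def block_key(token_ids: list[int]) -> str:
    """Derive a stable identifier from the tokens a block covers."""
    return hashlib.sha256(",".join(map(str, token_ids)).encode()).hexdigest()


def save_block(token_ids: list[int], kv_block: torch.Tensor) -> None:
    """Persist one KV cache block under its content hash."""
    torch.save(kv_block.cpu(), CACHE_DIR / f"{block_key(token_ids)}.pt")


def load_block(token_ids: list[int]):
    """Return a previously saved block, or None on a cache miss."""
    path = CACHE_DIR / f"{block_key(token_ids)}.pt"
    return torch.load(path) if path.exists() else None
```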

**How It Works:**

This example implements a `PersistentKvCacheConnector` with two key components
(a simplified sketch of the leader's decision logic follows the list):

* **PersistentKvCacheConnectorLeader (Scheduler):**
- Hashes token sequences to create unique identifiers for each cache block
- Checks if cached blocks exist on disk for incoming requests
- Schedules load operations for cache hits
- Schedules save operations for newly computed blocks

* **PersistentKvCacheConnectorWorker:**
- Executes the actual load/save operations between GPU and disk
- Loads cached blocks from disk files into GPU memory
- Saves newly computed blocks from GPU to disk files
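
The leader's decision flow can be pictured with a simplified, hypothetical sketch
(the names and block size below are illustrative only; the real interface is defined
in `tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py`):

```python
# Hypothetical leader-side planning: hash each full block of tokens, chaining in
# the previous block's hash so identity depends on the whole prefix, then decide
# per block whether it can be loaded from disk or must be computed and saved.
import hashlib
import os

BLOCK_SIZE = 32                        # assumed tokens per KV cache block
CACHE_DIR = "/tmp/kv_connector_cache"  # assumed storage location


def plan_blocks(token_ids: list[int]) -> list[tuple[str, str]]:
    """Return (block_hash, action) pairs, where action is 'load' or 'save'."""
    plan = []
    prev_hash = ""
    for i in range(len(token_ids) // BLOCK_SIZE):
        block = token_ids[i * BLOCK_SIZE:(i + 1) * BLOCK_SIZE]
        digest = hashlib.sha256(
            (prev_hash + ",".join(map(str, block))).encode()).hexdigest()
        on_disk = os.path.exists(os.path.join(CACHE_DIR, f"{digest}.pt"))
        plan.append((digest, "load" if on_disk else "save"))
        prev_hash = digest
    return plan
```

Blocks whose hashes are found on disk are scheduled as loads; the remaining blocks
are computed as usual and handed to the worker to be saved.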

**Demonstration:**

The script processes the same prompt twice using two separate LLM instances:

1. **First Run (Instance 1):**
- The LLM computes the KV cache for the input prompt
- The connector saves the computed cache blocks to disk (as .pt files)
- The generation completes and the LLM instance is destroyed

2. **Second Run (Instance 2):**
- A new LLM instance is created with the same connector configuration
- When processing the same prompt, the connector finds matching cache blocks on disk
- The cache is loaded from disk instead of being recomputed
- **Expected Outcome:** Faster prefill as cache blocks are loaded rather than computed
- Both outputs should be identical, demonstrating deterministic cache reuse

**Key Benefits:**

- **Cross-Instance Cache Sharing:** Share computed caches across multiple LLM instances
- **Persistent Storage:** Cache survives beyond the lifetime of a single LLM instance
- **Custom Storage Backends:** Implement any storage mechanism (shown here: disk files)
- **Reduced Computation:** Eliminate redundant KV cache computation for repeated prompts

**How to Run:**

```bash
python llm_kv_cache_connector.py <model_path>
```

Example:
```bash
python llm_kv_cache_connector.py meta-llama/Llama-3.1-8B-Instruct
```

**Implementation Notes:**

- This example uses content-based hashing to identify cache blocks
- Cache files are stored in a temporary directory (cleaned up after the demo)
- The implementation is simplified and not optimized for production use
- Chunked prefill is not supported in this example
- See `tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py` for the full connector interface

**NOTE:** This example connector implementation is designed for demonstration purposes
and is NOT suitable for production use without additional optimizations and error handling.
'''

import os
import sys
@@ -17,11 +95,6 @@
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
# See tensorrt_llm/_torch/pyexecutor/connector.py for details about the KV cache connector interface.
# NOTE: This example connector implementation is NOT suitable for production use.

CONNECTOR_CACHE_FOLDER_KEY = "CONNECTOR_CACHE_FOLDER"


@@ -198,6 +271,7 @@ def main(model: str):

this_module = __file__[__file__.rfind("/") + 1:__file__.rfind(".py")]

# --- KV Cache Connector Config ---
kv_connector_config = KvCacheConnectorConfig(
connector_module=this_module,
connector_scheduler_class="PersistentKvCacheConnectorLeader",
@@ -207,6 +281,7 @@
connector_cache_dir = TemporaryDirectory()
os.environ[CONNECTOR_CACHE_FOLDER_KEY] = connector_cache_dir.name

# Create LLM instance with KV Cache Connector
llm = LLM(model=model,
backend="pytorch",
cuda_graph_config=None,
@@ -220,6 +295,7 @@

sampling_params = SamplingParams(max_tokens=32)

    # Generate text with the first LLM instance; the connector saves the computed KV cache blocks to disk.
output = llm.generate([test_text], sampling_params)
text0 = output[0].outputs[0].text

@@ -228,16 +304,19 @@

del llm

# Create a new LLM instance with the same connector configuration
llm = LLM(model=model,
backend="pytorch",
cuda_graph_config=None,
kv_connector_config=kv_connector_config)

    # Generate text with the second LLM instance; it should reuse the KV cache blocks saved by the connector.
output = llm.generate([test_text], sampling_params)
text1 = output[0].outputs[0].text

print("Second output (using connector cache): ", text1)

# Verify that the two outputs are identical
assert text0 == text1

connector_cache_dir.cleanup()