
Commit 7664abb

Add specific AD configs for nano-v3

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 11132fe · commit 7664abb

3 files changed: +48 −0 lines changed

examples/auto_deploy/.gitignore (2 additions, 0 deletions)

@@ -5,3 +5,5 @@ benchmark_results.json
 # ignore config files that users might put here for debugging
 *.yaml
 !nano_v3.yaml
+!nano_v3_accuracy.yaml
+!nano_v3_bench.yaml
examples/auto_deploy/nano_v3_accuracy.yaml (23 additions, 0 deletions)

@@ -0,0 +1,23 @@
+runtime: trtllm
+compile_backend: torch-cudagraph
+max_batch_size: 128
+max_seq_len: 204800
+enable_chunked_prefill: true
+attn_backend: flashinfer
+model_factory: AutoModelForCausalLM
+skip_loading_weights: false
+free_mem_ratio: 0.9
+cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128]
+kv_cache_config:
+  # disable kv_cache reuse since not supported for hybrid/ssm models
+  enable_block_reuse: false
+transforms:
+  detect_sharding:
+    sharding_source: ['factory', 'heuristic']
+    sharding_dims: ['ep', 'bmm']
+  # tunable mamba cache dtype
+  # --> use float32 for accuracy and default (null) for speed
+  insert_cached_ssm_attention:
+    cache_config:
+      mamba_dtype: float32
+      # mamba_dtype: null
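
A minimal sketch of reading this config back in Python, for instance to confirm the accuracy variant's float32 mamba cache dtype. The path examples/auto_deploy/nano_v3_accuracy.yaml is inferred from the .gitignore diff above, and plain PyYAML is used here rather than any loader from the repository itself:

import yaml  # pip install pyyaml

# Assumed path, inferred from the .gitignore entries in this commit.
CONFIG_PATH = "examples/auto_deploy/nano_v3_accuracy.yaml"

with open(CONFIG_PATH) as f:
    cfg = yaml.safe_load(f)

# The accuracy config pins the mamba cache dtype to float32;
# the bench config leaves it at the default (null -> None) for speed.
cache_cfg = cfg["transforms"]["insert_cached_ssm_attention"]["cache_config"]
print("mamba_dtype:", cache_cfg["mamba_dtype"])
print("max_batch_size:", cfg["max_batch_size"])
print("cuda_graph_batch_sizes:", cfg["cuda_graph_batch_sizes"])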
examples/auto_deploy/nano_v3_bench.yaml (23 additions, 0 deletions)

@@ -0,0 +1,23 @@
+runtime: trtllm
+compile_backend: torch-cudagraph
+max_batch_size: 384 # tunable
+max_seq_len: 65536 # tunable
+enable_chunked_prefill: true
+attn_backend: flashinfer
+model_factory: AutoModelForCausalLM
+skip_loading_weights: false
+free_mem_ratio: 0.9
+cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64, 128, 256, 320, 384]
+kv_cache_config:
+  # disable kv_cache reuse since not supported for hybrid/ssm models
+  enable_block_reuse: false
+transforms:
+  detect_sharding:
+    sharding_source: ['factory', 'heuristic']
+    sharding_dims: ['ep', 'bmm']
+  # tunable mamba cache dtype
+  # --> use float32 for accuracy and default (null) for speed
+  insert_cached_ssm_attention:
+    cache_config:
+      # mamba_dtype: float32
+      mamba_dtype: null
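
For a quick side-by-side of the two variants, a hedged sketch (same path assumptions as above) that prints the top-level keys whose values differ between the accuracy and bench configs, e.g. max_batch_size, max_seq_len, cuda_graph_batch_sizes, and the nested mamba_dtype under transforms:

import yaml  # pip install pyyaml


def load(path):
    with open(path) as f:
        return yaml.safe_load(f)


# Assumed paths, inferred from the .gitignore entries in this commit.
acc = load("examples/auto_deploy/nano_v3_accuracy.yaml")
bench = load("examples/auto_deploy/nano_v3_bench.yaml")

# Print every top-level key whose value differs between the two variants.
for key in sorted(set(acc) | set(bench)):
    if acc.get(key) != bench.get(key):
        print(f"{key}: accuracy={acc.get(key)!r}  bench={bench.get(key)!r}")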
