 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from unittest.mock import Mock
-
 import torch
 
-from vllm.v1.attention.backends.flash_attn import (
-    FlashAttentionBackend, FlashAttentionMetadataBuilder)
-from vllm.v1.attention.backends.flex_attention import (
-    FlexAttentionBackend, FlexAttentionMetadataBuilder)
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
-from vllm.v1.worker.utils import (AttentionGroup,
-                                  initialize_kv_cache_for_kv_sharing)
+from vllm.v1.worker.utils import add_kv_sharing_layers_to_kv_cache_groups
 
 
 def new_kv_cache_spec():
@@ -37,56 +30,17 @@ def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
                          new_kv_cache_spec()),
     ]
 
-    attn_groups = [
-        # KV cache group 0 has two attention groups
-        [
-            AttentionGroup(
-                backend=FlashAttentionBackend,
-                metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
-                layer_names=["model.layers.0"],
-            ),
-            AttentionGroup(
-                backend=FlexAttentionBackend,
-                metadata_builder=Mock(spec=FlexAttentionMetadataBuilder),
-                layer_names=["model.layers.1"],
-            ),
-        ],
-    ]
-
-    # Only layers 0 and 1 will have KV caches allocated
-    kv_caches = {
-        "model.layers.0": torch.zeros(1, 2, 3),
-        "model.layers.1": torch.ones(1, 2, 3),
-    }
-
-    initialize_kv_cache_for_kv_sharing(
+    add_kv_sharing_layers_to_kv_cache_groups(
         shared_kv_cache_layers=shared_kv_cache_layers,
         kv_cache_groups=kv_cache_groups,
-        kv_caches=kv_caches,
-        attn_groups=attn_groups,
     )
 
-    # Check that the KV caches were shared correctly
-    assert kv_caches["model.layers.2"].data_ptr(
-    ) == kv_caches["model.layers.0"].data_ptr()
-    assert kv_caches["model.layers.3"].data_ptr(
-    ) == kv_caches["model.layers.1"].data_ptr()
-
     # Check that the layers were added to the correct KV cache group
     assert len(kv_cache_groups) == 1
     assert kv_cache_groups[0].layer_names == [
         "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
     ]
 
-    # Check that the layers were added to the attention groups
-    assert len(attn_groups) == 1 and len(attn_groups[0]) == 2
-    assert attn_groups[0][0].layer_names == [
-        "model.layers.0", "model.layers.2"
-    ]
-    assert attn_groups[0][1].layer_names == [
-        "model.layers.1", "model.layers.3"
-    ]
-
 
 def test_initialize_kv_cache_for_kv_sharing_same_attn_groups():
     """
@@ -103,48 +57,17 @@ def test_initialize_kv_cache_for_kv_sharing_same_attn_groups():
                          new_kv_cache_spec()),
     ]
 
-    attn_groups = [
-        # KV cache group 0 has a single attention group
-        # as all layers have the same flash attention backend
-        [
-            AttentionGroup(
-                backend=FlashAttentionBackend,
-                metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
-                layer_names=["model.layers.0", "model.layers.1"],
-            ),
-        ],
-    ]
-
-    kv_caches = {
-        "model.layers.0": torch.zeros(1, 2, 3),
-        "model.layers.1": torch.ones(1, 2, 3),
-    }
-
-    initialize_kv_cache_for_kv_sharing(
+    add_kv_sharing_layers_to_kv_cache_groups(
         shared_kv_cache_layers=shared_kv_cache_layers,
         kv_cache_groups=kv_cache_groups,
-        kv_caches=kv_caches,
-        attn_groups=attn_groups,
     )
 
-    # Check that the KV caches were shared correctly
-    assert kv_caches["model.layers.2"].data_ptr(
-    ) == kv_caches["model.layers.0"].data_ptr()
-    assert kv_caches["model.layers.3"].data_ptr(
-    ) == kv_caches["model.layers.1"].data_ptr()
-
     # Check that the layers were added to the correct KV cache group
     assert len(kv_cache_groups) == 1
     assert kv_cache_groups[0].layer_names == [
         "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
     ]
 
-    # Check that the layers were added to the attention groups
-    assert len(attn_groups) == 1 and len(attn_groups[0]) == 1
-    assert attn_groups[0][0].layer_names == [
-        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
-    ]
-
 
 def test_initialize_kv_cache_for_kv_sharing_no_attn_groups():
     """
@@ -162,23 +85,11 @@ def test_initialize_kv_cache_for_kv_sharing_no_attn_groups():
         KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()),
     ]
 
-    kv_caches = {
-        "model.layers.0": torch.zeros(1, 2, 3),
-        "model.layers.1": torch.ones(1, 2, 3),
-    }
-
-    initialize_kv_cache_for_kv_sharing(
+    add_kv_sharing_layers_to_kv_cache_groups(
         shared_kv_cache_layers=shared_kv_cache_layers,
         kv_cache_groups=kv_cache_groups,
-        kv_caches=kv_caches,
     )
 
-    # Check that the KV caches were shared correctly
-    assert kv_caches["model.layers.2"].data_ptr(
-    ) == kv_caches["model.layers.0"].data_ptr()
-    assert kv_caches["model.layers.3"].data_ptr(
-    ) == kv_caches["model.layers.1"].data_ptr()
-
     # Check that the layers were added to the correct KV cache group
     assert len(kv_cache_groups) == 2
     assert kv_cache_groups[0].layer_names == [
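
# ---------------------------------------------------------------------------
# For reference, a minimal sketch of what the helper exercised above is
# expected to do, inferred only from the assertions in these tests: each
# KV-sharing layer is appended to the KV cache group of the layer whose
# cache it reuses. The real implementation lives in vllm/v1/worker/utils.py
# and may differ; the body below (and the assumed child -> parent direction
# of the mapping) is an illustration, not the actual vLLM code.

def add_kv_sharing_layers_to_kv_cache_groups(
        shared_kv_cache_layers: dict[str, str],
        kv_cache_groups: list) -> None:
    # Map each layer name to the KV cache group that currently owns it.
    layer_to_group = {
        layer_name: group
        for group in kv_cache_groups
        for layer_name in group.layer_names
    }
    # Append every child layer to its parent's group, e.g. with
    # {"model.layers.2": "model.layers.0"}, "model.layers.2" joins the
    # group that already contains "model.layers.0".
    for child_layer, parent_layer in shared_kv_cache_layers.items():
        layer_to_group[parent_layer].layer_names.append(child_layer)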