 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import hashlib
+import pickle
 import tempfile
 from contextlib import contextmanager
+from unittest.mock import Mock, patch

 import pytest
 import torch

-from vllm.compilation.decorators import support_torch_compile
+from vllm.compilation.caching import VllmSerializableFunction
+from vllm.compilation.decorators import save_compile_cache, support_torch_compile
 from vllm.config import (
     CompilationConfig,
     CompilationMode,
@@ -39,6 +43,7 @@ def make_vllm_config() -> VllmConfig:
     return VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
+            use_inductor=True,
         )
     )

@@ -59,6 +64,8 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
         expected = reference_fn(*args)
         with use_vllm_config(vllm_config):
             m.setenv("VLLM_USE_AOT_COMPILE", "0")
+            m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
+            m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
             with (
                 pytest.raises(RuntimeError, match="Detected recompile"),
                 torch.compiler.set_stance("fail_on_recompile"),
@@ -79,6 +86,8 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
     with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
         args = (torch.randn(10, 10),)
         m.setenv("VLLM_USE_AOT_COMPILE", "1")
+        m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
+        m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
         m.setenv("VLLM_FORCE_AOT_LOAD", "1")
         m.setenv("VLLM_CACHE_ROOT", tmpdirname)
         vllm_config = make_vllm_config()
@@ -96,9 +105,13 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
         with tempfile.TemporaryDirectory() as tmpdirname:
             m.setenv("VLLM_CACHE_ROOT", tmpdirname)
             m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
+            m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
             vllm_config = make_vllm_config()
             with use_vllm_config(vllm_config):
-                expected = CompiledMod(vllm_config=vllm_config)(*args)
+                compiled_mod = CompiledMod(vllm_config=vllm_config)
+                expected = compiled_mod(*args)
+                save_compile_cache(compiled_mod)

             m.setenv("VLLM_FORCE_AOT_LOAD", "1")
             vllm_config = make_vllm_config()
@@ -121,13 +134,16 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
         with tempfile.TemporaryDirectory() as tmpdirname:
             m.setenv("VLLM_CACHE_ROOT", tmpdirname)
             m.setenv("VLLM_USE_AOT_COMPILE", "1")
+            m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
+            m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
             vllm_config = make_vllm_config()
             with use_vllm_config(vllm_config):
                 compiled_mod = CompiledMod(vllm_config=vllm_config)
                 compiled_mod(*args)
                 artifacts = compiled_mod.aot_compiled_fn._artifacts
                 guards_string = artifacts.compiled_fn.shape_env.format_guards()
                 assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
+                save_compile_cache(compiled_mod)

             m.setenv("VLLM_FORCE_AOT_LOAD", "1")
             vllm_config = make_vllm_config()
@@ -137,3 +153,227 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
                 artifacts = compiled_mod.aot_compiled_fn._artifacts
                 guards_string = artifacts.compiled_fn.shape_env.format_guards()
                 assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
+
+
+class TestInductorCompiledArtifacts:
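+    """Unit tests for VllmSerializableFunction.InductorCompiledArtifacts.
+
+    As exercised by the assertions below, the cache keys each entry by
+    "{submod_name}_{shape}" and maps it to the SHA-256 hash of the serialized
+    artifact (submodule_bytes), stores the bytes once per hash
+    (submodule_bytes_store, so identical artifacts are deduplicated), and
+    keeps deserialized artifacts in loaded_submodule_store, which is
+    dropped when the cache is pickled.
+    """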
+    def test_init(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+        assert cache.submodule_bytes == {}
+        assert cache.submodule_bytes_store == {}
+        assert cache.loaded_submodule_store == {}
+
+    def test_insert_new_artifact(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+        test_data = b"test_artifact_data"
+        submod_name = "test_submod"
+        shape = "s1"
+
+        hasher = hashlib.sha256()
+        hasher.update(test_data)
+        expected_hash = hasher.hexdigest()
+
+        cache.insert(submod_name, shape, test_data)
+
+        assert f"{submod_name}_{shape}" in cache.submodule_bytes
+        assert cache.submodule_bytes[f"{submod_name}_{shape}"] == expected_hash
+        assert expected_hash in cache.submodule_bytes_store
+        assert cache.submodule_bytes_store[expected_hash] == test_data
+
+    def test_insert_duplicate_artifact(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        test_data = b"duplicate_test_data"
+        submod_name1 = "submod1"
+        submod_name2 = "submod2"
+        shape = "s2"
+
+        cache.insert(submod_name1, shape, test_data)
+        cache.insert(submod_name2, shape, test_data)
+
+        hash1 = cache.submodule_bytes[f"{submod_name1}_{shape}"]
+        hash2 = cache.submodule_bytes[f"{submod_name2}_{shape}"]
+        assert hash1 == hash2
+
+        assert len(cache.submodule_bytes_store) == 1
+        assert len(cache.submodule_bytes) == 2
+
+    def test_get_artifact(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+        test_data = b"retrievable_data"
+        submod_name = "mod1"
+        shape = "shape16"
+
+        cache.insert(submod_name, shape, test_data)
+        retrieved_data = cache.get(submod_name, shape)
+
+        assert retrieved_data == test_data
+
+    def test_get_nonexistent_artifact(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        with pytest.raises(KeyError):
+            cache.get("nonexistent", "shape")
+
+    def test_size_bytes(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        assert cache.size_bytes() == 0
+
+        data1 = b"x" * 100
+        data2 = b"y" * 200
+        cache.insert("mod1", "shape1", data1)
+        cache.insert("mod2", "shape2", data2)
+
+        assert cache.size_bytes() == 300
+
+    def test_num_artifacts_and_entries(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        assert cache.num_artifacts() == 0
+        assert cache.num_entries() == 0
+
+        cache.insert("mod1", "shape1", b"data1")
+        cache.insert("mod2", "shape2", b"data2")
+        assert cache.num_artifacts() == 2
+        assert cache.num_entries() == 2
+
+        cache.insert("mod3", "shape3", b"data1")
+        assert cache.num_artifacts() == 2
+        assert cache.num_entries() == 3
+
+    @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
+    def test_load_all_success(self, mock_deserialize):
+        """Test successful loading of all artifacts"""
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        mock_artifact1 = Mock()
+        mock_artifact2 = Mock()
+        mock_deserialize.side_effect = [mock_artifact1, mock_artifact2]
+
+        cache.insert("mod1", "shape1", pickle.dumps(b"data1"))
+        cache.insert("mod2", "shape2", pickle.dumps(b"data2"))
+
+        cache.load_all()
+
+        assert len(cache.loaded_submodule_store) == 2
+        assert mock_deserialize.call_count == 2
+
+    @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
+    def test_load_all_already_loaded(self, mock_deserialize):
+        """Test that load_all skips if already loaded"""
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        mock_artifact = Mock()
+        cache.submodule_bytes_store["hash1"] = pickle.dumps(b"data1")
+        cache.loaded_submodule_store["hash1"] = mock_artifact
+
+        cache.load_all()
+
+        mock_deserialize.assert_not_called()
+
+    @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
+    def test_get_loaded_artifact(self, mock_deserialize):
+        """Test retrieving loaded artifacts"""
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        mock_artifact = Mock()
+        mock_deserialize.return_value = mock_artifact
+
+        submod_name = "test_mod"
+        shape = "test_shape"
+        cache.insert(submod_name, shape, pickle.dumps(b"test_data"))
+        cache.load_all()
+
+        retrieved_artifact = cache.get_loaded(submod_name, shape)
+        assert retrieved_artifact == mock_artifact
+
+    def test_getstate_setstate(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        cache.insert("mod1", "shape1", b"data1")
+        cache.insert("mod2", "shape2", b"data2")
+
+        cache.loaded_submodule_store["hash1"] = Mock()
+
+        state = cache.__getstate__()
+
+        assert "submodule_bytes" in state
+        assert "submodule_bytes_store" in state
+        assert "loaded_submodule_store" not in state
+
+        new_cache = VllmSerializableFunction.InductorCompiledArtifacts()
+        new_cache.__setstate__(state)
+
+        assert new_cache.submodule_bytes == cache.submodule_bytes
+        assert new_cache.submodule_bytes_store == cache.submodule_bytes_store
+        assert new_cache.loaded_submodule_store == {}
+
+    def test_pickle_roundtrip(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        test_data1 = b"pickle_test_data_1"
+        test_data2 = b"pickle_test_data_2"
+        cache.insert("mod1", "shape1", test_data1)
+        cache.insert("mod2", "shape2", test_data2)
+
+        pickled_data = pickle.dumps(cache)
+        restored_cache = pickle.loads(pickled_data)
+
+        assert restored_cache.get("mod1", "shape1") == test_data1
+        assert restored_cache.get("mod2", "shape2") == test_data2
+        assert restored_cache.num_artifacts() == cache.num_artifacts()
+        assert restored_cache.num_entries() == cache.num_entries()
+        assert restored_cache.size_bytes() == cache.size_bytes()
+
+        assert len(restored_cache.loaded_submodule_store) == 0
+
+
+class TestInductorCompiledArtifactsIntegration:
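+    """Round-trip and deduplication checks across multiple (submodule, shape)
+    entries, as inferred from the cache's observable behavior above.
+    """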
+    def test_add_pickle_unpickle(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        artifacts = {
+            ("mod1", "shape1"): b"m1s1_artifact",
+            ("mod1", "shape2"): b"m1s2_artifact",
+            ("mod2", "shape1"): b"m2s1_artifact",
+            ("mod2", "shape2"): b"m2s2_artifact",
+        }
+
+        for (submod, shape), data in artifacts.items():
+            cache.insert(submod, shape, data)
+
+        assert cache.num_entries() == 4
+        assert cache.num_artifacts() == 4
+
+        for (submod, shape), expected_data in artifacts.items():
+            retrieved_data = cache.get(submod, shape)
+            assert retrieved_data == expected_data
+
+        pickled = pickle.dumps(cache)
+        restored_cache = pickle.loads(pickled)
+
+        for (submod, shape), expected_data in artifacts.items():
+            retrieved_data = restored_cache.get(submod, shape)
+            assert retrieved_data == expected_data
+
+    def test_deduplication(self):
+        cache = VllmSerializableFunction.InductorCompiledArtifacts()
+
+        shared_data = b"shared_artifact_data" * 1000
+
+        cache.insert("mod1", "shape1", shared_data)
+        cache.insert("mod2", "shape1", shared_data)
+        cache.insert("mod1", "shape2", shared_data)
+        cache.insert("mod3", "shape3", shared_data)
+
+        assert cache.num_entries() == 4
+        assert cache.num_artifacts() == 1
+        assert cache.size_bytes() == len(shared_data)
+
+        for submod, shape in [
+            ("mod1", "shape1"),
+            ("mod2", "shape1"),
+            ("mod1", "shape2"),
+            ("mod3", "shape3"),
+        ]:
+            assert cache.get(submod, shape) == shared_data