Skip to content

Commit fca965f

Browse files
committed
serialize inductor artifacts
Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
1 parent 69f0640 commit fca965f

File tree

13 files changed

+2062
-117
lines changed

13 files changed

+2062
-117
lines changed

benchmarks/compile/benchmark_inductor_compiled_artifacts.py

Lines changed: 659 additions & 0 deletions
Large diffs are not rendered by default.

tests/compile/test_aot_compile.py

Lines changed: 242 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import hashlib
5+
import pickle
46
import tempfile
57
from contextlib import contextmanager
8+
from unittest.mock import Mock, patch
69

710
import pytest
811
import torch
912

10-
from vllm.compilation.decorators import support_torch_compile
13+
from vllm.compilation.caching import VllmSerializableFunction
14+
from vllm.compilation.decorators import save_compile_cache, support_torch_compile
1115
from vllm.config import (
1216
CompilationConfig,
1317
CompilationMode,
@@ -39,6 +43,7 @@ def make_vllm_config() -> VllmConfig:
3943
return VllmConfig(
4044
compilation_config=CompilationConfig(
4145
mode=CompilationMode.VLLM_COMPILE,
46+
backend="inductor",
4247
)
4348
)
4449

@@ -59,6 +64,8 @@ def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
5964
expected = reference_fn(*args)
6065
with use_vllm_config(vllm_config):
6166
m.setenv("VLLM_USE_AOT_COMPILE", "0")
67+
m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
68+
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
6269
with (
6370
pytest.raises(RuntimeError, match="Detected recompile"),
6471
torch.compiler.set_stance("fail_on_recompile"),
@@ -79,6 +86,8 @@ def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
7986
with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
8087
args = (torch.randn(10, 10),)
8188
m.setenv("VLLM_USE_AOT_COMPILE", "1")
89+
m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
90+
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
8291
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
8392
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
8493
vllm_config = make_vllm_config()
@@ -96,9 +105,13 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
96105
with tempfile.TemporaryDirectory() as tmpdirname:
97106
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
98107
m.setenv("VLLM_USE_AOT_COMPILE", "1")
108+
m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
109+
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
99110
vllm_config = make_vllm_config()
100111
with use_vllm_config(vllm_config):
101-
expected = CompiledMod(vllm_config=vllm_config)(*args)
112+
compiled_mod = CompiledMod(vllm_config=vllm_config)
113+
expected = compiled_mod(*args)
114+
save_compile_cache(compiled_mod)
102115

103116
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
104117
vllm_config = make_vllm_config()
@@ -121,13 +134,16 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
121134
with tempfile.TemporaryDirectory() as tmpdirname:
122135
m.setenv("VLLM_CACHE_ROOT", tmpdirname)
123136
m.setenv("VLLM_USE_AOT_COMPILE", "1")
137+
m.setenv("VLLM_USE_BACKEND_WITH_INDUCTOR_COMPILED_ARTIFACTS", "1")
138+
m.setenv("VLLM_USE_STANDALONE_COMPILE", "1")
124139
vllm_config = make_vllm_config()
125140
with use_vllm_config(vllm_config):
126141
compiled_mod = CompiledMod(vllm_config=vllm_config)
127142
compiled_mod(*args)
128143
artifacts = compiled_mod.aot_compiled_fn._artifacts
129144
guards_string = artifacts.compiled_fn.shape_env.format_guards()
130145
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
146+
save_compile_cache(compiled_mod)
131147

132148
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
133149
vllm_config = make_vllm_config()
@@ -137,3 +153,227 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
137153
artifacts = compiled_mod.aot_compiled_fn._artifacts
138154
guards_string = artifacts.compiled_fn.shape_env.format_guards()
139155
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
156+
157+
158+
class TestInductorCompiledArtifacts:
    """Unit tests for the content-addressed inductor artifact cache.

    The cache maps ``f"{submodule}_{shape}"`` keys to sha256 digests
    (``submodule_bytes``), digests to serialized blobs
    (``submodule_bytes_store``), and digests to deserialized artifacts
    (``loaded_submodule_store``).
    """

    def _new_cache(self):
        # Fresh, empty cache instance for each test.
        return VllmSerializableFunction.InductorCompiledArtifacts()

    def test_init(self):
        """A new cache starts with all three backing maps empty."""
        store = self._new_cache()
        assert store.submodule_bytes == {}
        assert store.submodule_bytes_store == {}
        assert store.loaded_submodule_store == {}

    def test_insert_new_artifact(self):
        """insert() keys the entry by name_shape and stores bytes by sha256."""
        store = self._new_cache()
        payload = b"test_artifact_data"
        name, shape = "test_submod", "s1"
        digest = hashlib.sha256(payload).hexdigest()

        store.insert(name, shape, payload)

        key = f"{name}_{shape}"
        assert key in store.submodule_bytes
        assert store.submodule_bytes[key] == digest
        assert digest in store.submodule_bytes_store
        assert store.submodule_bytes_store[digest] == payload

    def test_insert_duplicate_artifact(self):
        """Identical payloads under different names share one stored blob."""
        store = self._new_cache()
        payload = b"duplicate_test_data"
        shape = "s2"

        for name in ("submod1", "submod2"):
            store.insert(name, shape, payload)

        digest_one = store.submodule_bytes[f"submod1_{shape}"]
        digest_two = store.submodule_bytes[f"submod2_{shape}"]
        assert digest_one == digest_two

        # Two entries, one deduplicated blob.
        assert len(store.submodule_bytes_store) == 1
        assert len(store.submodule_bytes) == 2

    def test_get_artifact(self):
        """get() round-trips the exact bytes that were inserted."""
        store = self._new_cache()
        payload = b"retrievable_data"

        store.insert("mod1", "shape16", payload)

        assert store.get("mod1", "shape16") == payload

    def test_get_nonexistent_artifact(self):
        """Looking up an entry that was never inserted raises KeyError."""
        store = self._new_cache()
        with pytest.raises(KeyError):
            store.get("nonexistent", "shape")

    def test_size_bytes(self):
        """size_bytes() sums the lengths of the distinct stored blobs."""
        store = self._new_cache()
        assert store.size_bytes() == 0

        store.insert("mod1", "shape1", b"x" * 100)
        store.insert("mod2", "shape2", b"y" * 200)

        assert store.size_bytes() == 300

    def test_num_artifacts_and_entries(self):
        """num_entries() counts keys; num_artifacts() counts unique blobs."""
        store = self._new_cache()
        assert store.num_artifacts() == 0
        assert store.num_entries() == 0

        store.insert("mod1", "shape1", b"data1")
        store.insert("mod2", "shape2", b"data2")
        assert store.num_artifacts() == 2
        assert store.num_entries() == 2

        # Re-inserting an already-stored payload adds an entry, not a blob.
        store.insert("mod3", "shape3", b"data1")
        assert store.num_artifacts() == 2
        assert store.num_entries() == 3

    @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
    def test_load_all_success(self, mock_deserialize):
        """load_all() deserializes every stored blob into the loaded map."""
        store = self._new_cache()

        mock_deserialize.side_effect = [Mock(), Mock()]

        store.insert("mod1", "shape1", pickle.dumps(b"data1"))
        store.insert("mod2", "shape2", pickle.dumps(b"data2"))
        store.load_all()

        assert len(store.loaded_submodule_store) == 2
        assert mock_deserialize.call_count == 2

    @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
    def test_load_all_already_loaded(self, mock_deserialize):
        """load_all() skips blobs whose artifacts are already deserialized."""
        store = self._new_cache()
        store.submodule_bytes_store["hash1"] = pickle.dumps(b"data1")
        store.loaded_submodule_store["hash1"] = Mock()

        store.load_all()

        mock_deserialize.assert_not_called()

    @patch("torch._inductor.standalone_compile.AOTCompiledArtifact.deserialize")
    def test_get_loaded_artifact(self, mock_deserialize):
        """get_loaded() returns the deserialized object for a (name, shape)."""
        store = self._new_cache()
        artifact = Mock()
        mock_deserialize.return_value = artifact

        store.insert("test_mod", "test_shape", pickle.dumps(b"test_data"))
        store.load_all()

        assert store.get_loaded("test_mod", "test_shape") == artifact

    def test_getstate_setstate(self):
        """Pickle state carries the byte maps but drops loaded artifacts."""
        store = self._new_cache()
        store.insert("mod1", "shape1", b"data1")
        store.insert("mod2", "shape2", b"data2")
        # Deserialized objects must not be part of the serialized state.
        store.loaded_submodule_store["hash1"] = Mock()

        state = store.__getstate__()

        assert "submodule_bytes" in state
        assert "submodule_bytes_store" in state
        assert "loaded_submodule_store" not in state

        clone = self._new_cache()
        clone.__setstate__(state)

        assert clone.submodule_bytes == store.submodule_bytes
        assert clone.submodule_bytes_store == store.submodule_bytes_store
        assert clone.loaded_submodule_store == {}

    def test_pickle_roundtrip(self):
        """A pickled-then-restored cache serves the same data and stats."""
        store = self._new_cache()
        first = b"pickle_test_data_1"
        second = b"pickle_test_data_2"
        store.insert("mod1", "shape1", first)
        store.insert("mod2", "shape2", second)

        restored = pickle.loads(pickle.dumps(store))

        assert restored.get("mod1", "shape1") == first
        assert restored.get("mod2", "shape2") == second
        assert restored.num_artifacts() == store.num_artifacts()
        assert restored.num_entries() == store.num_entries()
        assert restored.size_bytes() == store.size_bytes()

        # Loaded (deserialized) artifacts never survive pickling.
        assert len(restored.loaded_submodule_store) == 0
329+
330+
331+
class TestInductorCompiledArtifactsIntegration:
    """End-to-end checks combining insert, lookup, pickling and dedup."""

    def test_add_pickle_unpickle(self):
        """Distinct (submodule, shape) entries survive a pickle round trip."""
        cache = VllmSerializableFunction.InductorCompiledArtifacts()

        artifacts = {
            ("mod1", "shape1"): b"m1s1_artifact",
            ("mod1", "shape2"): b"m1s2_artifact",
            ("mod2", "shape1"): b"m2s1_artifact",
            ("mod2", "shape2"): b"m2s2_artifact",
        }
        for (name, shape), blob in artifacts.items():
            cache.insert(name, shape, blob)

        # Four unique payloads -> four entries and four stored blobs.
        assert cache.num_entries() == 4
        assert cache.num_artifacts() == 4
        for (name, shape), blob in artifacts.items():
            assert cache.get(name, shape) == blob

        revived = pickle.loads(pickle.dumps(cache))
        for (name, shape), blob in artifacts.items():
            assert revived.get(name, shape) == blob

    def test_deduplication(self):
        """One shared payload is stored once however many keys map to it."""
        cache = VllmSerializableFunction.InductorCompiledArtifacts()
        blob = b"shared_artifact_data" * 1000

        keys = [
            ("mod1", "shape1"),
            ("mod2", "shape1"),
            ("mod1", "shape2"),
            ("mod3", "shape3"),
        ]
        for name, shape in keys:
            cache.insert(name, shape, blob)

        # Four entries, but the identical payload is stored exactly once.
        assert cache.num_entries() == 4
        assert cache.num_artifacts() == 1
        assert cache.size_bytes() == len(blob)

        for name, shape in keys:
            assert cache.get(name, shape) == blob

0 commit comments

Comments
 (0)