Skip to content

Commit 03fa8b3

Browse files
committed
serialize inductor artifacts
1 parent 3873813 commit 03fa8b3

File tree

7 files changed

+724
-106
lines changed

7 files changed

+724
-106
lines changed

tests/compile/test_aot_compile.py

Lines changed: 250 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
import hashlib
5+
import pickle
46
import tempfile
57
from contextlib import contextmanager
8+
from unittest.mock import Mock, patch
69

710
import pytest
811
import torch
912

13+
from vllm.compilation.backends import VllmCompiledFunction
1014
from vllm.compilation.decorators import support_torch_compile
1115
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
1216
set_current_vllm_config)
@@ -91,7 +95,251 @@ def test_basic(monkeypatch: pytest.MonkeyPatch):
9195
m.setenv("VLLM_USE_AOT_COMPILE", "1")
9296
vllm_config = make_vllm_config()
9397
with use_vllm_config(vllm_config):
94-
expected = CompiledMod(vllm_config=vllm_config)(*args)
98+
mod = CompiledMod(vllm_config=vllm_config)
99+
expected = mod(*args)
95100
m.setenv("VLLM_FORCE_AOT_LOAD", "1")
96-
ret = CompiledMod(vllm_config=vllm_config)(*args)
101+
ret = mod(*args)
97102
assert torch.allclose(ret, expected)
103+
104+
105+
class TestInductorCache:
    """Unit tests for ``VllmCompiledFunction.InductorCache``.

    The cache maps ``"{submodule}_{shape}"`` keys to content hashes and
    deduplicates identical artifact bytes behind those hashes.
    """

    def test_init(self):
        """A freshly constructed cache has all three stores empty."""
        cache = VllmCompiledFunction.InductorCache()
        assert cache.submodule_bytes == {}
        assert cache.submodule_bytes_store == {}
        assert cache.loaded_submodule_store == {}

    def test_insert_new_artifact(self):
        """insert() records the sha256 of the bytes and stores the payload."""
        cache = VllmCompiledFunction.InductorCache()
        payload = b"test_artifact_data"
        name, shape = "test_submod", "s1"
        digest = hashlib.sha256(payload).hexdigest()

        cache.insert(name, shape, payload)

        key = f"{name}_{shape}"
        assert key in cache.submodule_bytes
        assert cache.submodule_bytes[key] == digest
        assert digest in cache.submodule_bytes_store
        assert cache.submodule_bytes_store[digest] == payload

    def test_insert_duplicate_artifact(self):
        """Two inserts of identical bytes share one stored artifact."""
        cache = VllmCompiledFunction.InductorCache()
        payload = b"duplicate_test_data"
        shape = "s2"

        for name in ("submod1", "submod2"):
            cache.insert(name, shape, payload)

        digests = [
            cache.submodule_bytes[f"{name}_{shape}"]
            for name in ("submod1", "submod2")
        ]
        assert digests[0] == digests[1]

        # Two logical entries, one deduplicated payload.
        assert len(cache.submodule_bytes_store) == 1
        assert len(cache.submodule_bytes) == 2

    def test_get_artifact(self):
        """get() returns the exact bytes previously inserted."""
        cache = VllmCompiledFunction.InductorCache()
        payload = b"retrievable_data"

        cache.insert("mod1", "shape16", payload)

        assert cache.get("mod1", "shape16") == payload

    def test_get_nonexistent_artifact(self):
        """get() on a key that was never inserted raises KeyError."""
        cache = VllmCompiledFunction.InductorCache()

        with pytest.raises(KeyError):
            cache.get("nonexistent", "shape")

    def test_size_bytes(self):
        """size_bytes() sums the lengths of the stored payloads."""
        cache = VllmCompiledFunction.InductorCache()
        assert cache.size_bytes() == 0

        cache.insert("mod1", "shape1", b"x" * 100)
        cache.insert("mod2", "shape2", b"y" * 200)

        assert cache.size_bytes() == 300

    def test_num_artifacts_and_entries(self):
        """Entries count keys; artifacts count deduplicated payloads."""
        cache = VllmCompiledFunction.InductorCache()
        assert cache.num_artifacts() == 0
        assert cache.num_entries() == 0

        cache.insert("mod1", "shape1", b"data1")
        cache.insert("mod2", "shape2", b"data2")
        assert cache.num_artifacts() == 2
        assert cache.num_entries() == 2

        # Re-inserting identical bytes adds an entry but no new artifact.
        cache.insert("mod3", "shape3", b"data1")
        assert cache.num_artifacts() == 2
        assert cache.num_entries() == 3

    @patch("torch._inductor.CompiledArtifact.from_bytes")
    def test_load_all_success(self, mock_from_bytes):
        """Test successful loading of all artifacts"""
        cache = VllmCompiledFunction.InductorCache()

        mock_from_bytes.side_effect = [Mock(), Mock()]

        cache.insert("mod1", "shape1", b"data1")
        cache.insert("mod2", "shape2", b"data2")
        cache.load_all()

        assert len(cache.loaded_submodule_store) == 2
        assert mock_from_bytes.call_count == 2

    @patch("torch._inductor.CompiledArtifact.from_bytes")
    def test_load_all_with_retry(self, mock_from_bytes):
        """Test loading with retries on initial failure"""
        cache = VllmCompiledFunction.InductorCache()

        recovered = Mock()
        # First deserialization attempt blows up; the retry succeeds.
        mock_from_bytes.side_effect = [
            Exception("First attempt fails"), recovered
        ]

        cache.insert("mod1", "shape1", b"data1")
        cache.load_all()

        assert len(cache.loaded_submodule_store) == 1
        assert mock_from_bytes.call_count == 2

    @patch("torch._inductor.CompiledArtifact.from_bytes")
    def test_load_all_already_loaded(self, mock_from_bytes):
        """Test that load_all skips if already loaded"""
        cache = VllmCompiledFunction.InductorCache()

        cache.submodule_bytes_store["hash1"] = b"data1"
        cache.loaded_submodule_store["hash1"] = Mock()

        cache.load_all()

        # Already-deserialized artifacts must not be re-loaded.
        mock_from_bytes.assert_not_called()

    @patch("torch._inductor.CompiledArtifact.from_bytes")
    def test_get_loaded_artifact(self, mock_from_bytes):
        """Test retrieving loaded artifacts"""
        cache = VllmCompiledFunction.InductorCache()

        artifact = Mock()
        mock_from_bytes.return_value = artifact

        cache.insert("test_mod", "test_shape", b"test_data")
        cache.load_all()

        assert cache.get_loaded("test_mod", "test_shape") == artifact

    def test_getstate_setstate(self):
        """Pickle state carries the bytes but drops loaded artifacts."""
        cache = VllmCompiledFunction.InductorCache()
        cache.insert("mod1", "shape1", b"data1")
        cache.insert("mod2", "shape2", b"data2")
        cache.loaded_submodule_store["hash1"] = Mock()

        state = cache.__getstate__()

        assert "submodule_bytes" in state
        assert "submodule_bytes_store" in state
        # Deserialized artifacts are intentionally excluded from the state.
        assert "loaded_submodule_store" not in state

        restored = VllmCompiledFunction.InductorCache()
        restored.__setstate__(state)

        assert restored.submodule_bytes == cache.submodule_bytes
        assert restored.submodule_bytes_store == cache.submodule_bytes_store
        assert restored.loaded_submodule_store == {}

    def test_pickle_roundtrip(self):
        """A pickled-then-unpickled cache preserves bytes and counters."""
        cache = VllmCompiledFunction.InductorCache()
        blobs = {
            ("mod1", "shape1"): b"pickle_test_data_1",
            ("mod2", "shape2"): b"pickle_test_data_2",
        }
        for (name, shape), data in blobs.items():
            cache.insert(name, shape, data)

        restored = pickle.loads(pickle.dumps(cache))

        for (name, shape), data in blobs.items():
            assert restored.get(name, shape) == data
        assert restored.num_artifacts() == cache.num_artifacts()
        assert restored.num_entries() == cache.num_entries()
        assert restored.size_bytes() == cache.size_bytes()

        # Loaded artifacts never survive pickling.
        assert len(restored.loaded_submodule_store) == 0
296+
class TestInductorCacheIntegration:
    """End-to-end checks combining insert, pickle, and deduplication."""

    def test_add_pickle_unpickle(self):
        """Every inserted artifact survives a full pickle round trip."""
        cache = VllmCompiledFunction.InductorCache()

        artifacts = {
            ("mod1", "shape1"): b"m1s1_artifact",
            ("mod1", "shape2"): b"m1s2_artifact",
            ("mod2", "shape1"): b"m2s1_artifact",
            ("mod2", "shape2"): b"m2s2_artifact",
        }
        for (submod, shape), data in artifacts.items():
            cache.insert(submod, shape, data)

        # All four payloads are distinct, so entries == artifacts.
        assert cache.num_entries() == 4
        assert cache.num_artifacts() == 4
        for (submod, shape), expected in artifacts.items():
            assert cache.get(submod, shape) == expected

        restored = pickle.loads(pickle.dumps(cache))
        for (submod, shape), expected in artifacts.items():
            assert restored.get(submod, shape) == expected

    def test_deduplication(self):
        """Identical payloads are stored once regardless of key count."""
        cache = VllmCompiledFunction.InductorCache()
        shared = b"shared_artifact_data" * 1000

        keys = [
            ("mod1", "shape1"),
            ("mod2", "shape1"),
            ("mod1", "shape2"),
            ("mod3", "shape3"),
        ]
        for submod, shape in keys:
            cache.insert(submod, shape, shared)

        # Four keys, one deduplicated payload, payload stored only once.
        assert cache.num_entries() == 4
        assert cache.num_artifacts() == 1
        assert cache.size_bytes() == len(shared)

        for submod, shape in keys:
            assert cache.get(submod, shape) == shared
0 commit comments

Comments (0)