[CI Failure]: AssertionError on graph pickler while serializing symint nodes. #27348

@zhxchen17

Description

Name of failing test

tests/lora/test_quant_model.py

Basic information

  • Flaky test
  • Can reproduce locally
  • Caused by external libraries (e.g. bug in transformers)

🧪 Describe the failing test

Reproducible with torch 2.10:

(EngineCore_DP0 pid=4544) ERROR 10-15 10:01:02 [core.py:790] AssertionError
(EngineCore_DP0 pid=4544) Process EngineCore_DP0:
(EngineCore_DP0 pid=4544) Traceback (most recent call last):
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=4544)     self.run()
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=4544)     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 794, in run_engine_core
(EngineCore_DP0 pid=4544)     raise e
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 781, in run_engine_core
(EngineCore_DP0 pid=4544)     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=4544)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 553, in __init__
(EngineCore_DP0 pid=4544)     super().__init__(
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 110, in __init__
(EngineCore_DP0 pid=4544)     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=4544)                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 221, in _initialize_kv_caches
(EngineCore_DP0 pid=4544)     available_gpu_memory = self.model_executor.determine_available_memory()
(EngineCore_DP0 pid=4544)                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 87, in determine_available_memory
(EngineCore_DP0 pid=4544)     return self.collective_rpc("determine_available_memory")
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 73, in collective_rpc
(EngineCore_DP0 pid=4544)     return [run_method(self.driver_worker, method, args, kwargs)]
(EngineCore_DP0 pid=4544)             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/utils/__init__.py", line 2975, in run_method
(EngineCore_DP0 pid=4544)     return func(*args, **kwargs)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 122, in decorate_context
(EngineCore_DP0 pid=4544)     return func(*args, **kwargs)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 279, in determine_available_memory
(EngineCore_DP0 pid=4544)     self.model_runner.profile_run()
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 3699, in profile_run
(EngineCore_DP0 pid=4544)     hidden_states, last_hidden_states = self._dummy_run(
(EngineCore_DP0 pid=4544)                                         ^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 122, in decorate_context
(EngineCore_DP0 pid=4544)     return func(*args, **kwargs)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 3452, in _dummy_run
(EngineCore_DP0 pid=4544)     outputs = self.model(
(EngineCore_DP0 pid=4544)               ^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/compilation/cuda_graph.py", line 125, in __call__
(EngineCore_DP0 pid=4544)     return self.runnable(*args, **kwargs)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
(EngineCore_DP0 pid=4544)     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1795, in _call_impl
(EngineCore_DP0 pid=4544)     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/qwen2_vl.py", line 1598, in forward
(EngineCore_DP0 pid=4544)     hidden_states = self.language_model.model(
(EngineCore_DP0 pid=4544)                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 404, in __call__
(EngineCore_DP0 pid=4544)     self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_dynamo/aot_compile.py", line 118, in save_compiled_function
(EngineCore_DP0 pid=4544)     f.write(type(self).serialize(self))
(EngineCore_DP0 pid=4544)             ^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_dynamo/aot_compile.py", line 130, in serialize
(EngineCore_DP0 pid=4544)     type(compiled_fn).serialize_compile_artifacts(compiled_fn),
(EngineCore_DP0 pid=4544)     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/compilation/caching.py", line 87, in serialize_compile_artifacts
(EngineCore_DP0 pid=4544)     state["graph_module"] = GraphPickler.dumps(
(EngineCore_DP0 pid=4544)                             ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/fx/_graph_pickler.py", line 125, in dumps
(EngineCore_DP0 pid=4544)     pickler.dump(obj)
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/compilation/caching.py", line 80, in _graph_reducer_override
(EngineCore_DP0 pid=4544)     return graph_reducer_override(self, obj)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544)   File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/fx/_graph_pickler.py", line 103, in reducer_override
(EngineCore_DP0 pid=4544)     assert not isinstance(obj, torch.fx.Node)
(EngineCore_DP0 pid=4544)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=4544) AssertionError

There seems to be a serialization bug on the PyTorch side while saving the graph module.
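
For context, here is a toy illustration (not the actual vLLM or PyTorch code) of the failure mode: GraphPickler's reducer_override asserts that it never encounters a raw torch.fx.Node, so any fx.Node reachable from the GraphModule's pickled state (per the issue title, apparently via a SymInt node here) trips the assertion.

import io
import pickle

import torch.fx

class ToyGraphPickler(pickle.Pickler):
    def reducer_override(self, obj):
        # Mirrors the check at torch/fx/_graph_pickler.py:103 in the
        # traceback above.
        assert not isinstance(obj, torch.fx.Node)
        return NotImplemented  # defer to normal pickling for everything else

g = torch.fx.Graph()
stray_node = g.placeholder("x")  # a raw fx.Node leaking into pickled state

try:
    ToyGraphPickler(io.BytesIO()).dump({"meta": stray_node})
except AssertionError:
    print("AssertionError, as in the CI failure above")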

Since it's possible to hit bugs like this during serialization, and we're still using AOT compile as the Dynamo caching mechanism, we should have some fallback behavior (e.g. abort serialization and fall back to the old path of retracing with Dynamo; see the sketch below). I will send a PR to implement a fallback mechanism as a short-term fix.
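
A minimal sketch of such a fallback, assuming the call site in vllm/compilation/decorators.py shown in the traceback; the wrapper name and logging are hypothetical, while save_compiled_function and the compilation path argument come from the traceback above.

import logging

logger = logging.getLogger(__name__)

def save_compiled_function_with_fallback(aot_compiled_fn, path: str) -> None:
    """Best-effort persistence of the AOT-compiled function.

    If serialization fails (e.g. the GraphPickler AssertionError above),
    log and skip caching instead of crashing EngineCore; the next startup
    simply retraces with Dynamo, as on the old path.
    """
    try:
        # Real API per the traceback: save_compiled_function (defined in
        # torch/_dynamo/aot_compile.py) serializes the compiled artifacts.
        aot_compiled_fn.save_compiled_function(path)
    except Exception:
        logger.warning(
            "Failed to serialize AOT-compiled function to %s; "
            "falling back to Dynamo retracing.",
            path,
            exc_info=True,
        )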

In the long term, we should be able to serialize Inductor artifacts directly, so there will be no need to serialize the graph module at all. This issue should therefore go away once Inductor bundling is merged (e.g. #25205).

📝 History of failing test

The test failure should have started when we landed AOT compilation in vLLM. Official vLLM releases shouldn't be affected, because they rely on torch 2.8/2.9 and we only turn on AOT compile for torch 2.10.

CC List

@zou3519 @ProExpertProg
