Skip to content

Commit ea3370b

Browse files
[ROCm][Bugfix] Patch for the Multi-Modal Processor Test group (#29702)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
1 parent c625d7b commit ea3370b

File tree

4 files changed

+104
-28
lines changed

4 files changed

+104
-28
lines changed

docker/Dockerfile.rocm

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
6565
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
6666
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
6767
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
68+
# Centralized v1 package - copied to both test and final stages
69+
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/vllm/v1 /vllm_v1
6870

6971
# -----------------------
7072
# Test vLLM image
@@ -88,10 +90,22 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
8890

8991
# install development dependencies (for testing)
9092
RUN cd /vllm-workspace \
91-
&& rm -rf vllm \
9293
&& python3 -m pip install -e tests/vllm_test_utils \
9394
&& python3 -m pip install pytest-shard
9495

96+
# Enable fast downloads from the Hugging Face Hub via hf_transfer (for testing)
97+
RUN --mount=type=cache,target=/root/.cache/uv \
98+
uv pip install --system hf_transfer
99+
ENV HF_HUB_ENABLE_HF_TRANSFER=1
100+
101+
# Copy in the v1 package
102+
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
103+
104+
# Source code is used in the `python_only_compile.sh` test
105+
# We hide it inside `src/` so that this source code
106+
# will not be imported by other tests
107+
RUN mkdir src && mv vllm src/vllm
108+
95109
# -----------------------
96110
# Final vLLM image
97111
FROM base AS final
@@ -116,6 +130,9 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
116130
&& pip uninstall -y vllm \
117131
&& uv pip install --system *.whl
118132

133+
# Copy in the v1 package
134+
COPY --from=export_vllm /vllm_v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
135+
119136
ARG COMMON_WORKDIR
120137

121138
# Copy over the benchmark scripts as well

docker/Dockerfile.rocm_base

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ ARG PYTORCH_BRANCH="1c57644d"
55
ARG PYTORCH_VISION_BRANCH="v0.23.0"
66
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
77
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
8+
ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
9+
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
810
ARG FA_BRANCH="0e60e394"
911
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
1012
ARG AITER_BRANCH="59bd8ff2"
@@ -23,6 +25,7 @@ ENV AITER_ROCM_ARCH=gfx942;gfx950
2325
ENV HSA_NO_SCRATCH_RECLAIM=1
2426

2527
ARG PYTHON_VERSION=3.12
28+
ENV PYTHON_VERSION=${PYTHON_VERSION}
2629

2730
RUN mkdir -p /app
2831
WORKDIR /app
@@ -45,6 +48,7 @@ RUN apt-get update -y \
4548
&& python3 --version && python3 -m pip --version
4649

4750
RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
51+
RUN apt-get update && apt-get install -y libjpeg-dev libsox-dev libsox-fmt-all sox && rm -rf /var/lib/apt/lists/*
4852

4953
FROM base AS build_triton
5054
ARG TRITON_BRANCH
@@ -66,20 +70,30 @@ RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
6670
FROM base AS build_pytorch
6771
ARG PYTORCH_BRANCH
6872
ARG PYTORCH_VISION_BRANCH
73+
ARG PYTORCH_AUDIO_BRANCH
6974
ARG PYTORCH_REPO
7075
ARG PYTORCH_VISION_REPO
76+
ARG PYTORCH_AUDIO_REPO
77+
7178
RUN git clone ${PYTORCH_REPO} pytorch
72-
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
73-
pip install -r requirements.txt && git submodule update --init --recursive \
79+
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \
80+
&& pip install -r requirements.txt && git submodule update --init --recursive \
7481
&& python3 tools/amd_build/build_amd.py \
7582
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
7683
&& pip install dist/*.whl
7784
RUN git clone ${PYTORCH_VISION_REPO} vision
7885
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
7986
&& python3 setup.py bdist_wheel --dist-dir=dist \
8087
&& pip install dist/*.whl
88+
RUN git clone ${PYTORCH_AUDIO_REPO} audio
89+
RUN cd audio && git checkout ${PYTORCH_AUDIO_BRANCH} \
90+
&& git submodule update --init --recursive \
91+
&& pip install -r requirements.txt \
92+
&& python3 setup.py bdist_wheel --dist-dir=dist \
93+
&& pip install dist/*.whl
8194
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
82-
&& cp /app/vision/dist/*.whl /app/install
95+
&& cp /app/vision/dist/*.whl /app/install \
96+
&& cp /app/audio/dist/*.whl /app/install
8397

8498
FROM base AS build_fa
8599
ARG FA_BRANCH
@@ -130,6 +144,8 @@ ARG PYTORCH_BRANCH
130144
ARG PYTORCH_VISION_BRANCH
131145
ARG PYTORCH_REPO
132146
ARG PYTORCH_VISION_REPO
147+
ARG PYTORCH_AUDIO_BRANCH
148+
ARG PYTORCH_AUDIO_REPO
133149
ARG FA_BRANCH
134150
ARG FA_REPO
135151
ARG AITER_BRANCH
@@ -141,7 +157,9 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
141157
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
142158
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
143159
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
160+
&& echo "PYTORCH_AUDIO_BRANCH: ${PYTORCH_AUDIO_BRANCH}" >> /app/versions.txt \
161+
&& echo "PYTORCH_AUDIO_REPO: ${PYTORCH_AUDIO_REPO}" >> /app/versions.txt \
144162
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
145163
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
146164
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
147-
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
165+
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt

requirements/rocm-test.txt

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,85 @@
11
# Common dependencies
22
-r common.txt
3+
4+
# Test infrastructure
35
tblib==3.1.0
4-
bm25s==0.2.13
5-
pystemmer==3.0.0
6+
pytest==8.3.5
7+
pytest-asyncio==0.24.0
8+
pytest-timeout==2.3.1
9+
pytest-cov==6.3.0
10+
pytest-forked==1.6.0
11+
pytest-rerunfailures==14.0
12+
pytest-shard==0.1.2
13+
14+
# Async/HTTP dependencies
15+
anyio==4.6.2.post1
16+
# via httpx, starlette
17+
aiohttp==3.13.0
18+
# via gpt-oss
19+
httpx==0.27.2
20+
# HTTP testing
621

7-
# Entrypoints test
8-
# librosa==0.10.2.post1 # required by audio tests in entrypoints/openai
22+
# Audio processing dependencies
923
audioread==3.0.1
24+
# via librosa
1025
cffi==1.17.1
26+
# via soundfile
1127
decorator==5.2.1
28+
# via librosa
1229
lazy-loader==0.4
30+
# via librosa
1331
platformdirs==4.3.6
32+
# via pooch
1433
pooch==1.8.2
15-
#pycparse==2.22
34+
# via librosa
1635
soundfile==0.13.1
36+
# via librosa
1737
soxr==0.5.0.post1
38+
# via librosa
1839
librosa==0.10.2.post1
1940

20-
# Entrypoints test
21-
#vllm[video] # required by entrypoints/openai/test_video.py
22-
decord==0.6.0
23-
24-
# Entrypoints test
25-
#sentence-transformers # required by entrypoints/openai/test_score.py
26-
sentence-transformers==3.4.1
27-
28-
# Basic Models Test
29-
matplotlib==3.10.3
41+
# Retrieval and search
42+
bm25s==0.2.13
43+
# via mteb
44+
pystemmer==3.0.0
45+
# via mteb
3046

31-
# Multi-Modal Models Test (Extended) 3
47+
# Multi-modal processing
3248
blobfile==3.0.0
49+
# Multi-Modal Models Test
50+
decord==0.6.0
51+
# video processing, required by entrypoints/openai/test_video.py
3352

34-
# Required for openai schema test.
53+
# OpenAI compatibility and testing
54+
gpt-oss==0.0.8
55+
# OpenAI compatibility tests
3556
schemathesis==3.39.15
57+
# OpenAI schema test
3658

37-
# Required for mteb test
38-
mteb[bm25s]>=1.38.11, <2
39-
40-
# Required for eval tests
59+
# Evaluation and benchmarking
4160
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
4261

43-
# Required for multiprocessed tests that use spawn method
62+
# Required for multiprocessed tests that use the spawn method, and for the Datasets and Evaluate tests
4463
multiprocess==0.70.16
4564

4665
# Plugins test
4766
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
4867
torchgeo==0.7.0
68+
# via terratorch
69+
# MTEB Benchmark Test
70+
mteb==2.1.2
71+
72+
# Data processing
73+
xgrammar @ git+https://github.com/mlc-ai/xgrammar.git@eafd4db51b78acc64b3f0764ef27dfd206c28628
74+
# Test async scheduling
75+
76+
# Utilities
77+
num2words==0.5.14
78+
# via lm-eval
79+
pqdm==0.2.0
80+
# via lm-eval
4981

5082
# Required for suffix decoding test
5183
arctic-inference == 0.1.1
84+
# Required for Nemotron test
85+
open-clip-torch==2.32.0

tests/models/multimodal/processing/test_tensor_schema.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
3131
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
3232
from vllm.multimodal.utils import group_mm_kwargs_by_modality
33+
from vllm.platforms import current_platform
3334
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
3435
from vllm.utils.collection_utils import is_list_of
3536
from vllm.utils.torch_utils import set_default_torch_dtype
@@ -176,6 +177,12 @@ def test_model_tensor_schema(model_id: str):
176177
exist_overrides=model_info.hf_overrides,
177178
)
178179

180+
# On ROCm, models whose id contains "awq" (assumed AWQ-quantized) use float16; otherwise keep model_info.dtype
181+
if "awq" in model_id.lower() and current_platform.is_rocm():
182+
dtype = "float16"
183+
else:
184+
dtype = model_info.dtype
185+
179186
model_config = ModelConfig(
180187
model_id,
181188
tokenizer=model_info.tokenizer or model_id,
@@ -187,7 +194,7 @@ def test_model_tensor_schema(model_id: str):
187194
enable_prompt_embeds=model_info.require_embed_inputs,
188195
enable_mm_embeds=model_info.require_embed_inputs,
189196
enforce_eager=model_info.enforce_eager,
190-
dtype=model_info.dtype,
197+
dtype=dtype,
191198
)
192199

193200
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

0 commit comments

Comments
 (0)