@@ -20,8 +20,8 @@ ARG PYTHON_VERSION=3.12
2020# glibc version is baked into the distro, and binaries built with one glibc
2121# version are not backwards compatible with OSes that use an earlier version.
2222ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
23- # TODO: Restore to base image after FlashInfer AOT wheel fixed
24- ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
23+ # Using cuda base image with minimal dependencies necessary for JIT compilation (FlashInfer, DeepGEMM, EP kernels)
24+ ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04
2525
2626# By parameterizing the Deadsnakes repository URL, we allow third-party to use
2727# their own mirror. When doing so, we don't benefit from the transparent
@@ -328,6 +328,18 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
328328 && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
329329 && python3 --version && python3 -m pip --version
330330
331+ # Install the CUDA tools needed for runtime JIT compilation (FlashInfer, DeepGEMM, EP kernels).
332+ # The packages below carry a MAJOR-MINOR suffix, so CUDA_VERSION x.y.z is reduced to "x-y" first.
333+ RUN CUDA_VERSION_DASH=$(echo "$CUDA_VERSION" | cut -d. -f1,2 | tr '.' '-') && \
334+ apt-get update -y && \
335+ apt-get install -y --no-install-recommends \
336+ cuda-nvcc-${CUDA_VERSION_DASH} \
337+ cuda-cudart-${CUDA_VERSION_DASH} \
338+ cuda-nvrtc-${CUDA_VERSION_DASH} \
339+ cuda-cuobjdump-${CUDA_VERSION_DASH} \
340+ libcublas-${CUDA_VERSION_DASH} && \
341+ rm -rf /var/lib/apt/lists/*
342+
331343ARG PIP_INDEX_URL UV_INDEX_URL
332344ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
333345ARG PYTORCH_CUDA_INDEX_BASE_URL
0 commit comments