Skip to content

Commit 9c5bf34

Browse files
authored
fix: multi-cuda version skew (ollama#12318)
Ensure that in a version skewed multi-cuda setup we use the lowest version for all GPUs
1 parent 564b558 commit 9c5bf34

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

discover/cuda_common.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
1717
var CudaTegra string = os.Getenv("JETSON_JETPACK")
1818

19-
func cudaVariant(gpuInfo CudaGPUInfo) string {
19+
func cudaVariant(gpuInfos []CudaGPUInfo) string {
2020
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
2121
if CudaTegra != "" {
2222
ver := strings.Split(CudaTegra, ".")
@@ -45,20 +45,19 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
4545
}
4646
}
4747

48-
// Check GPU compute capability FIRST
49-
isOldGPU := gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5)
50-
if isOldGPU {
51-
// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
52-
return "v12"
48+
// Check GPU compute capability FIRST, lowest common denominator if multi-gpu
49+
for _, gpuInfo := range gpuInfos {
50+
if gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) {
51+
// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
52+
return "v12"
53+
}
5354
}
5455

5556
// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
56-
if gpuInfo.DriverMajor < 13 {
57+
if len(gpuInfos) > 0 && gpuInfos[0].DriverMajor < 13 {
5758
// The detected driver is older than 580 (Aug 2025)
5859
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
59-
if !isOldGPU {
60-
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
61-
}
60+
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfos[0].DriverMajor, gpuInfos[0].DriverMinor))
6261
return "v12"
6362
}
6463
return "v13"

discover/gpu.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -284,18 +284,8 @@ func GetGPUInfo() GpuInfoList {
284284
gpuInfo.MinimumMemory = cudaMinimumMemory
285285
gpuInfo.DriverMajor = driverMajor
286286
gpuInfo.DriverMinor = driverMinor
287-
variant := cudaVariant(gpuInfo)
288-
289-
// Start with our bundled libraries
290-
if variant != "" {
291-
variantPath := filepath.Join(LibOllamaPath, "cuda_"+variant)
292-
if _, err := os.Stat(variantPath); err == nil {
293-
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
294-
gpuInfo.DependencyPath = append([]string{variantPath}, gpuInfo.DependencyPath...)
295-
}
296-
}
287+
297288
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
298-
gpuInfo.Variant = variant
299289

300290
if int(memInfo.major) < cudaComputeMajorMin || (int(memInfo.major) == cudaComputeMajorMin && int(memInfo.minor) < cudaComputeMinorMin) {
301291
unsupportedGPUs = append(unsupportedGPUs,
@@ -333,6 +323,24 @@ func GetGPUInfo() GpuInfoList {
333323
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
334324
cudaGPUs = append(cudaGPUs, gpuInfo)
335325
}
326+
// Second pass on NVIDIA GPUs to set lowest common denominator variant and DependencyPaths
327+
variant := cudaVariant(cudaGPUs)
328+
var variantPath string
329+
// Start with our bundled libraries
330+
if variant != "" {
331+
variantPath = filepath.Join(LibOllamaPath, "cuda_"+variant)
332+
if _, err := os.Stat(variantPath); err != nil {
333+
variantPath = ""
334+
}
335+
}
336+
337+
for i := range cudaGPUs {
338+
cudaGPUs[i].Variant = variant
339+
if variantPath != "" {
340+
// Put the variant directory first in the search path to avoid runtime linking to the wrong library
341+
cudaGPUs[i].DependencyPath = append([]string{variantPath}, cudaGPUs[i].DependencyPath...)
342+
}
343+
}
336344
}
337345

338346
// Intel

0 commit comments

Comments
 (0)