@@ -18,12 +18,15 @@ package nvcdi
1818
1919import (
2020 "fmt"
21+ "path/filepath"
22+ "strings"
2123
2224 "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2325 "github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
2426 "github.com/NVIDIA/go-nvml/pkg/nvml"
2527
2628 "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
29+ "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/cuda"
2730 "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
2831 "github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
2932 "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
@@ -184,18 +187,36 @@ func (l *nvcdilib) resolveMode() (rmode string) {
184187 return ModeNvml
185188}
186189
187- // getCudaVersion returns the CUDA version of the current system.
188- func (l * nvcdilib ) getCudaVersion () (string , error ) {
189- version , err := l .getCudaVersionNvsandboxutils ()
190- if err == nil {
190+ // getDriverVersion returns the driver version of the current system.
191+ func (l * nvcdilib ) getDriverVersion () (string , error ) {
192+ if version , err := l .getDriverVersionNvsandboxutils (); err == nil && version != "" {
191193 return version , err
192194 }
193195
194196 // Fallback to NVML
195- return l .getCudaVersionNvml ()
197+ if version , err := l .getDriverVersionNvml (); err == nil && version != "" {
198+ return version , err
199+ }
200+
201+ // Fallback to getting the version from the libcuda.so suffix.
202+ return l .getDriverVersionLibcudaSo ()
203+ }
204+
205+ func (l * nvcdilib ) getDriverVersionLibcudaSo () (string , error ) {
206+ libCudaPaths , err := cuda .New (
207+ l .driver .Libraries (),
208+ ).Locate (".*.*" )
209+ if err != nil {
210+ return "" , fmt .Errorf ("failed to locate libcuda.so: %v" , err )
211+ }
212+ libCudaPath := libCudaPaths [0 ]
213+
214+ version := strings .TrimPrefix (filepath .Base (libCudaPath ), "libcuda.so." )
215+
216+ return version , nil
196217}
197218
198- func (l * nvcdilib ) getCudaVersionNvml () (string , error ) {
219+ func (l * nvcdilib ) getDriverVersionNvml () (string , error ) {
199220 if hasNVML , reason := l .infolib .HasNvml (); ! hasNVML {
200221 return "" , fmt .Errorf ("nvml not detected: %v" , reason )
201222 }
@@ -219,7 +240,7 @@ func (l *nvcdilib) getCudaVersionNvml() (string, error) {
219240 return version , nil
220241}
221242
222- func (l * nvcdilib ) getCudaVersionNvsandboxutils () (string , error ) {
243+ func (l * nvcdilib ) getDriverVersionNvsandboxutils () (string , error ) {
223244 if l .nvsandboxutilslib == nil {
224245 return "" , fmt .Errorf ("libnvsandboxutils is not available" )
225246 }
0 commit comments