@@ -41,31 +41,31 @@ public static SystemInfo Get()
4141
4242 return new SystemInfo ( platform , GetCudaMajorVersion ( ) , GetVulkanVersion ( ) ) ;
4343 }
44-
44+
4545 #region Vulkan version
4646 private static string ? GetVulkanVersion ( )
4747 {
4848 // Get Vulkan Summary
4949 string ? vulkanSummary = GetVulkanSummary ( ) ;
50-
50+
5151 // If we have a Vulkan summary
5252 if ( vulkanSummary != null )
5353 {
5454 // Extract Vulkan version from summary
5555 string ? vulkanVersion = ExtractVulkanVersionFromSummary ( vulkanSummary ) ;
56-
56+
5757 // If we have a Vulkan version
5858 if ( vulkanVersion != null )
5959 {
6060 // Return the Vulkan version
6161 return vulkanVersion ;
6262 }
6363 }
64-
64+
6565 // Return null if we failed to get the Vulkan version
6666 return null ;
6767 }
68-
68+
6969 private static string ? GetVulkanSummary ( )
7070 {
7171 // Note: on Linux, this requires `vulkan-tools` to be installed. (`sudo apt install vulkan-tools`)
@@ -102,19 +102,19 @@ public static SystemInfo Get()
102102 // We have three ways of parsing the Vulkan version from the summary (output is a different between Windows and Linux)
103103 // For now, I have decided to go with the full version number, and leave it up to the user to parse it further if needed
104104 // I have left the other patterns in, in case we need them in the future
105-
105+
106106 // Output on linux : 4206847 (1.3.255)
107107 // Output on windows : 1.3.255
108108 string pattern = @"apiVersion\s*=\s*([^\r\n]+)" ;
109-
109+
110110 // Output on linux : 4206847
111111 // Output on windows : 1.3.255
112112 //string pattern = @"apiVersion\s*=\s*([\d\.]+)";
113-
113+
114114 // Output on linux : 1.3.255
115115 // Output on windows : 1.3.255
116116 //string pattern = @"apiVersion\s*=\s*(?:\d+\s*)?(?:\(\s*)?([\d]+\.[\d]+\.[\d]+)(?:\s*\))?";
117-
117+
118118 // Create a Regex object to match the pattern
119119 Regex regex = new Regex ( pattern ) ;
120120
@@ -158,24 +158,30 @@ private static int GetCudaMajorVersion()
158158 }
159159 else if ( RuntimeInformation . IsOSPlatform ( OSPlatform . Linux ) )
160160 {
161+ string ? env_version = Environment . GetEnvironmentVariable ( "CUDA_VERSION" ) ;
162+ if ( env_version is not null )
163+ {
164+ return ExtractMajorVersion ( ref env_version ) ;
165+ }
166+
161167 // List of default cuda paths
162168 string [ ] defaultCudaPaths =
163169 [
164170 "/usr/local/bin/cuda" ,
165171 "/usr/local/cuda" ,
166172 ] ;
167-
173+
168174 // Loop through every default path to find the version
169175 foreach ( var path in defaultCudaPaths )
170176 {
171177 // Attempt to get the version from the path
172178 version = GetCudaVersionFromPath ( path ) ;
173-
179+
174180 // If a CUDA version is found, break the loop
175181 if ( ! string . IsNullOrEmpty ( version ) )
176182 break ;
177183 }
178-
184+
179185 if ( string . IsNullOrEmpty ( version ) )
180186 {
181187 cudaPath = Environment . GetEnvironmentVariable ( "LD_LIBRARY_PATH" ) ;
@@ -197,6 +203,11 @@ private static int GetCudaMajorVersion()
197203 if ( string . IsNullOrEmpty ( version ) )
198204 return - 1 ;
199205
206+ return ExtractMajorVersion ( ref version ) ;
207+ }
208+
209+ private static int ExtractMajorVersion ( ref string version )
210+ {
200211 version = version . Split ( '.' ) [ 0 ] ;
201212 if ( int . TryParse ( version , out var majorVersion ) )
202213 return majorVersion ;
0 commit comments