@@ -18,6 +18,7 @@ package modifier
1818
1919import (
2020 "fmt"
21+ "strings"
2122
2223 "tags.cncf.io/container-device-interface/pkg/parser"
2324
@@ -27,17 +28,27 @@ import (
2728 "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
2829 "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2930 "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
30- "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
31+ )
32+
33+ const (
34+ automaticDeviceVendor = "runtime.nvidia.com"
35+ automaticDeviceClass = "gpu"
36+ automaticDeviceKind = automaticDeviceVendor + "/" + automaticDeviceClass
37+ automaticDevicePrefix = automaticDeviceKind + "="
3138)
3239
3340// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
3441// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
3542// used to select the devices to include.
36- func NewCDIModifier (logger logger.Interface , cfg * config.Config , image image.CUDA ) (oci.SpecModifier , error ) {
43+ func NewCDIModifier (logger logger.Interface , cfg * config.Config , image image.CUDA , isJitCDI bool ) (oci.SpecModifier , error ) {
44+ defaultKind := cfg .NVIDIAContainerRuntimeConfig .Modes .CDI .DefaultKind
45+ if isJitCDI {
46+ defaultKind = automaticDeviceKind
47+ }
3748 deviceRequestor := newCDIDeviceRequestor (
3849 logger ,
3950 image ,
40- cfg . NVIDIAContainerRuntimeConfig . Modes . CDI . DefaultKind ,
51+ defaultKind ,
4152 )
4253 devices := deviceRequestor .DeviceRequests ()
4354 if len (devices ) == 0 {
@@ -107,50 +118,46 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string {
107118func filterAutomaticDevices (devices []string ) []string {
108119 var automatic []string
109120 for _ , device := range devices {
110- vendor , class , _ := parser .ParseDevice (device )
111- if vendor == "runtime.nvidia.com" && class == "gpu" {
112- automatic = append (automatic , device )
121+ if ! strings .HasPrefix (device , automaticDevicePrefix ) {
122+ continue
113123 }
124+ automatic = append (automatic , device )
114125 }
115126 return automatic
116127}
117128
118129func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
119130 logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
120- spec , err := generateAutomaticCDISpec (logger , cfg , devices )
121- if err != nil {
122- return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
123- }
124- cdiDeviceRequestor , err := cdi .New (
125- cdi .WithLogger (logger ),
126- cdi .WithSpec (spec .Raw ()),
127- )
128- if err != nil {
129- return nil , fmt .Errorf ("failed to construct CDI modifier: %w" , err )
130- }
131131
132- return cdiDeviceRequestor , nil
133- }
132+ var identifiers []string
133+ for _ , device := range devices {
134+ identifiers = append (identifiers , strings .TrimPrefix (device , automaticDevicePrefix ))
135+ }
134136
135- func generateAutomaticCDISpec (logger logger.Interface , cfg * config.Config , devices []string ) (spec.Interface , error ) {
136137 cdilib , err := nvcdi .New (
137138 nvcdi .WithLogger (logger ),
138139 nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
139140 nvcdi .WithDriverRoot (cfg .NVIDIAContainerCLIConfig .Root ),
140- nvcdi .WithVendor ("runtime.nvidia.com" ),
141- nvcdi .WithClass ("gpu" ),
141+ nvcdi .WithVendor (automaticDeviceVendor ),
142+ nvcdi .WithClass (automaticDeviceClass ),
142143 )
143144 if err != nil {
144145 return nil , fmt .Errorf ("failed to construct CDI library: %w" , err )
145146 }
146147
147- var identifiers []string
148- for _ , device := range devices {
149- _ , _ , id := parser .ParseDevice (device )
150- identifiers = append (identifiers , id )
148+ spec , err := cdilib .GetSpec (identifiers ... )
149+ if err != nil {
150+ return nil , fmt .Errorf ("failed to generate CDI spec: %w" , err )
151+ }
152+ cdiDeviceRequestor , err := cdi .New (
153+ cdi .WithLogger (logger ),
154+ cdi .WithSpec (spec .Raw ()),
155+ )
156+ if err != nil {
157+ return nil , fmt .Errorf ("failed to construct CDI modifier: %w" , err )
151158 }
152159
153- return cdilib . GetSpec ( identifiers ... )
160+ return cdiDeviceRequestor , nil
154161}
155162
156163type deduplicatedDeviceRequestor struct {
0 commit comments