@@ -166,38 +166,24 @@ func filterAutomaticDevices(devices []string) []string {
166166func newAutomaticCDISpecModifier (logger logger.Interface , cfg * config.Config , devices []string ) (oci.SpecModifier , error ) {
167167 logger .Debugf ("Generating in-memory CDI specs for devices %v" , devices )
168168
169- perModeIdentifiers := make (map [string ][]string )
170- perModeDeviceClass := map [string ]string {"auto" : automaticDeviceClass }
171- uniqueModes := []string {"auto" }
172- seen := make (map [string ]bool )
173- for _ , device := range devices {
174- mode , id := getModeIdentifier (device )
175- logger .Debugf ("Mapped %v to %v: %v" , device , mode , id )
176- if ! seen [mode ] {
177- uniqueModes = append (uniqueModes , mode )
178- seen [mode ] = true
179- }
180- if id != "" {
181- perModeIdentifiers [mode ] = append (perModeIdentifiers [mode ], id )
182- }
183- }
169+ cdiModeIdentifiers := cdiModeIdentfiersFromDevices (devices ... )
184170
185- logger .Debugf ("Per-mode identifiers: %v" , perModeIdentifiers )
171+ logger .Debugf ("Per-mode identifiers: %v" , cdiModeIdentifiers )
186172 var modifiers oci.SpecModifiers
187- for _ , mode := range uniqueModes {
173+ for _ , mode := range cdiModeIdentifiers . modes {
188174 cdilib , err := nvcdi .New (
189175 nvcdi .WithLogger (logger ),
190176 nvcdi .WithNVIDIACDIHookPath (cfg .NVIDIACTKConfig .Path ),
191177 nvcdi .WithDriverRoot (cfg .NVIDIAContainerCLIConfig .Root ),
192178 nvcdi .WithVendor (automaticDeviceVendor ),
193- nvcdi .WithClass (perModeDeviceClass [mode ]),
179+ nvcdi .WithClass (cdiModeIdentifiers . deviceClassByMode [mode ]),
194180 nvcdi .WithMode (mode ),
195181 )
196182 if err != nil {
197183 return nil , fmt .Errorf ("failed to construct CDI library for mode %q: %w" , mode , err )
198184 }
199185
200- spec , err := cdilib .GetSpec (perModeIdentifiers [mode ]... )
186+ spec , err := cdilib .GetSpec (cdiModeIdentifiers . idsByMode [mode ]... )
201187 if err != nil {
202188 return nil , fmt .Errorf ("failed to generate CDI spec for mode %q: %w" , mode , err )
203189 }
@@ -216,6 +202,35 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
216202 return modifiers , nil
217203}
218204
205+ type cdiModeIdentifiers struct {
206+ modes []string
207+ idsByMode map [string ][]string
208+ deviceClassByMode map [string ]string
209+ }
210+
211+ func cdiModeIdentfiersFromDevices (devices ... string ) * cdiModeIdentifiers {
212+ perModeIdentifiers := make (map [string ][]string )
213+ perModeDeviceClass := map [string ]string {"auto" : automaticDeviceClass }
214+ uniqueModes := []string {"auto" }
215+ seen := make (map [string ]bool )
216+ for _ , device := range devices {
217+ mode , id := getModeIdentifier (device )
218+ if ! seen [mode ] {
219+ uniqueModes = append (uniqueModes , mode )
220+ seen [mode ] = true
221+ }
222+ if id != "" {
223+ perModeIdentifiers [mode ] = append (perModeIdentifiers [mode ], id )
224+ }
225+ }
226+
227+ return & cdiModeIdentifiers {
228+ modes : uniqueModes ,
229+ idsByMode : perModeIdentifiers ,
230+ deviceClassByMode : perModeDeviceClass ,
231+ }
232+ }
233+
219234func getModeIdentifier (device string ) (string , string ) {
220235 if ! strings .HasPrefix (device , "mode=" ) {
221236 return "auto" , strings .TrimPrefix (device , automaticDevicePrefix )
0 commit comments