Skip to content

Commit a33c7d8

Browse files
authored
Merge pull request #1431 from elezar/fix-duplicate-auto-mode
Fix duplicate CDI spec generation in jit-cdi mode
2 parents 39ca52d + 6b95511 commit a33c7d8

File tree

4 files changed

+128
-20
lines changed

4 files changed

+128
-20
lines changed

internal/modifier/cdi.go

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -166,39 +166,25 @@ func filterAutomaticDevices(devices []string) []string {
166166
func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) {
167167
logger.Debugf("Generating in-memory CDI specs for devices %v", devices)
168168

169-
perModeIdentifiers := make(map[string][]string)
170-
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
171-
uniqueModes := []string{"auto"}
172-
seen := make(map[string]bool)
173-
for _, device := range devices {
174-
mode, id := getModeIdentifier(device)
175-
logger.Debugf("Mapped %v to %v: %v", device, mode, id)
176-
if !seen[mode] {
177-
uniqueModes = append(uniqueModes, mode)
178-
seen[mode] = true
179-
}
180-
if id != "" {
181-
perModeIdentifiers[mode] = append(perModeIdentifiers[mode], id)
182-
}
183-
}
169+
cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...)
184170

185-
logger.Debugf("Per-mode identifiers: %v", perModeIdentifiers)
171+
logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers)
186172
var modifiers oci.SpecModifiers
187-
for _, mode := range uniqueModes {
173+
for _, mode := range cdiModeIdentifiers.modes {
188174
cdilib, err := nvcdi.New(
189175
nvcdi.WithLogger(logger),
190176
nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path),
191177
nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root),
192178
nvcdi.WithVendor(automaticDeviceVendor),
193-
nvcdi.WithClass(perModeDeviceClass[mode]),
179+
nvcdi.WithClass(cdiModeIdentifiers.deviceClassByMode[mode]),
194180
nvcdi.WithMode(mode),
195181
nvcdi.WithFeatureFlags(cfg.NVIDIAContainerRuntimeConfig.Modes.JitCDI.NVCDIFeatureFlags...),
196182
)
197183
if err != nil {
198184
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)
199185
}
200186

201-
spec, err := cdilib.GetSpec(perModeIdentifiers[mode]...)
187+
spec, err := cdilib.GetSpec(cdiModeIdentifiers.idsByMode[mode]...)
202188
if err != nil {
203189
return nil, fmt.Errorf("failed to generate CDI spec for mode %q: %w", mode, err)
204190
}
@@ -217,6 +203,35 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
217203
return modifiers, nil
218204
}
219205

206+
type cdiModeIdentifiers struct {
207+
modes []string
208+
idsByMode map[string][]string
209+
deviceClassByMode map[string]string
210+
}
211+
212+
func cdiModeIdentfiersFromDevices(devices ...string) *cdiModeIdentifiers {
213+
perModeIdentifiers := make(map[string][]string)
214+
perModeDeviceClass := map[string]string{"auto": automaticDeviceClass}
215+
var uniqueModes []string
216+
seen := make(map[string]bool)
217+
for _, device := range devices {
218+
mode, id := getModeIdentifier(device)
219+
if !seen[mode] {
220+
uniqueModes = append(uniqueModes, mode)
221+
seen[mode] = true
222+
}
223+
if id != "" {
224+
perModeIdentifiers[mode] = append(perModeIdentifiers[mode], id)
225+
}
226+
}
227+
228+
return &cdiModeIdentifiers{
229+
modes: uniqueModes,
230+
idsByMode: perModeIdentifiers,
231+
deviceClassByMode: perModeDeviceClass,
232+
}
233+
}
234+
220235
func getModeIdentifier(device string) (string, string) {
221236
if !strings.HasPrefix(device, "mode=") {
222237
return "auto", strings.TrimPrefix(device, automaticDevicePrefix)

internal/modifier/cdi_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,86 @@ func TestDeviceRequests(t *testing.T) {
170170
})
171171
}
172172
}
173+
174+
func Test_cdiModeIdentfiersFromDevices(t *testing.T) {
175+
testCases := []struct {
176+
description string
177+
devices []string
178+
expected *cdiModeIdentifiers
179+
}{
180+
{
181+
description: "empty device list",
182+
devices: []string{},
183+
expected: &cdiModeIdentifiers{
184+
modes: nil,
185+
idsByMode: map[string][]string{},
186+
deviceClassByMode: map[string]string{"auto": "gpu"},
187+
},
188+
},
189+
{
190+
description: "single automatic device",
191+
devices: []string{"0"},
192+
expected: &cdiModeIdentifiers{
193+
modes: []string{"auto"},
194+
idsByMode: map[string][]string{"auto": {"0"}},
195+
deviceClassByMode: map[string]string{"auto": "gpu"},
196+
},
197+
},
198+
{
199+
description: "multiple automatic devices",
200+
devices: []string{"0", "1"},
201+
expected: &cdiModeIdentifiers{
202+
modes: []string{"auto"},
203+
idsByMode: map[string][]string{"auto": {"0", "1"}},
204+
deviceClassByMode: map[string]string{"auto": "gpu"},
205+
},
206+
},
207+
{
208+
description: "device with explicit mode",
209+
devices: []string{"mode=gds,id=foo"},
210+
expected: &cdiModeIdentifiers{
211+
modes: []string{"gds"},
212+
idsByMode: map[string][]string{"gds": {"foo"}},
213+
deviceClassByMode: map[string]string{"auto": "gpu"},
214+
},
215+
},
216+
{
217+
description: "mixed auto and explicit",
218+
devices: []string{"0", "mode=gds,id=foo", "mode=gdrcopy,id=bar"},
219+
expected: &cdiModeIdentifiers{
220+
modes: []string{"auto", "gds", "gdrcopy"},
221+
idsByMode: map[string][]string{
222+
"auto": {"0"},
223+
"gds": {"foo"},
224+
"gdrcopy": {"bar"},
225+
},
226+
deviceClassByMode: map[string]string{"auto": "gpu"},
227+
},
228+
},
229+
{
230+
description: "device with only mode, no id",
231+
devices: []string{"mode=nvswitch"},
232+
expected: &cdiModeIdentifiers{
233+
modes: []string{"nvswitch"},
234+
idsByMode: map[string][]string{},
235+
deviceClassByMode: map[string]string{"auto": "gpu"},
236+
},
237+
},
238+
{
239+
description: "duplicate modes",
240+
devices: []string{"mode=gds,id=x", "mode=gds,id=y", "mode=gds"},
241+
expected: &cdiModeIdentifiers{
242+
modes: []string{"gds"},
243+
idsByMode: map[string][]string{"gds": {"x", "y"}},
244+
deviceClassByMode: map[string]string{"auto": "gpu"},
245+
},
246+
},
247+
}
248+
249+
for _, tc := range testCases {
250+
t.Run(tc.description, func(t *testing.T) {
251+
result := cdiModeIdentfiersFromDevices(tc.devices...)
252+
require.EqualValues(t, tc.expected, result)
253+
})
254+
}
255+
}

tests/e2e/nvidia-container-toolkit_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,16 @@ var _ = Describe("docker", Ordered, ContinueOnFailure, func() {
198198
Expect(err).ToNot(HaveOccurred())
199199
Expect(ldconfigOut).To(ContainSubstring("/usr/local/cuda-12.9/compat/"))
200200
})
201+
202+
It("should create a single ld.so.conf.d config file", func(ctx context.Context) {
203+
lsout, _, err := runner.Run("docker run --rm -i -e NVIDIA_DISABLE_REQUIRE=true --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all nvcr.io/nvidia/cuda:12.9.0-base-ubi8 bash -c \"ls -l /etc/ld.so.conf.d/00-compat-*.conf\"")
204+
Expect(err).ToNot(HaveOccurred())
205+
Expect(lsout).To(WithTransform(
206+
func(s string) []string {
207+
return strings.Split(strings.TrimSpace(s), "\n")
208+
}, HaveLen(1),
209+
))
210+
})
201211
})
202212

203213
When("Disabling device node creation", Ordered, func() {

0 commit comments

Comments
 (0)