Skip to content

Commit 208896d

Browse files
authored
Merge pull request #1130 from ArangoGutierrez/fix/1049
BUGFIX: modifier: respect GPU volume-mount device requests
2 parents 8149be0 + 82b6289 commit 208896d

File tree

7 files changed

+153
-54
lines changed

7 files changed

+153
-54
lines changed

internal/config/image/cuda_image.go

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package image
1919
import (
2020
"fmt"
2121
"path/filepath"
22+
"slices"
2223
"strconv"
2324
"strings"
2425

@@ -143,8 +144,8 @@ func (i CUDA) HasDisableRequire() bool {
143144
return false
144145
}
145146

146-
// DevicesFromEnvvars returns the devices requested by the image through environment variables
147-
func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
147+
// devicesFromEnvvars returns the devices requested by the image through environment variables
148+
func (i CUDA) devicesFromEnvvars(envVars ...string) []string {
148149
// We concatenate all the devices from the specified env.
149150
var isSet bool
150151
var devices []string
@@ -165,15 +166,15 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
165166

166167
// Environment variable unset with legacy image: default to "all".
167168
if !isSet && len(devices) == 0 && i.IsLegacy() {
168-
return NewVisibleDevices("all")
169+
devices = []string{"all"}
169170
}
170171

171172
// Environment variable unset or empty or "void": return nil
172173
if len(devices) == 0 || requested["void"] {
173-
return NewVisibleDevices("void")
174+
devices = []string{"void"}
174175
}
175176

176-
return NewVisibleDevices(devices...)
177+
return NewVisibleDevices(devices...).List()
177178
}
178179

179180
// GetDriverCapabilities returns the requested driver capabilities.
@@ -232,6 +233,22 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
232233
return hasCDIdevice
233234
}
234235

236+
// visibleEnvVars returns the environment variables that are used to determine device visibility.
237+
// It returns the preferred environment variables that are set, or NVIDIA_VISIBLE_DEVICES if none are set.
238+
func (i CUDA) visibleEnvVars() []string {
239+
var envVars []string
240+
for _, envVar := range i.preferredVisibleDeviceEnvVars {
241+
if !i.HasEnvvar(envVar) {
242+
continue
243+
}
244+
envVars = append(envVars, envVar)
245+
}
246+
if len(envVars) > 0 {
247+
return envVars
248+
}
249+
return []string{EnvVarNvidiaVisibleDevices}
250+
}
251+
235252
// VisibleDevices returns a list of devices requested in the container image.
236253
// If volume mount requests are enabled these are returned if requested,
237254
// otherwise device requests through environment variables are considered.
@@ -253,7 +270,7 @@ func (i CUDA) VisibleDevices() []string {
253270
}
254271

255272
// Fall back to reading from the environment variable if privileges are correct
256-
envVarDeviceRequests := i.VisibleDevicesFromEnvVar()
273+
envVarDeviceRequests := i.visibleDevicesFromEnvVar()
257274
if len(envVarDeviceRequests) == 0 {
258275
return nil
259276
}
@@ -265,7 +282,10 @@ func (i CUDA) VisibleDevices() []string {
265282
}
266283

267284
// We log a warning if we are ignoring the environment variable requests.
268-
i.logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES in unprivileged container")
285+
envVars := i.visibleEnvVars()
286+
if len(envVars) > 0 {
287+
i.logger.Warningf("Ignoring devices requested by environment variable(s) in unprivileged container: %v", envVars)
288+
}
269289

270290
return nil
271291
}
@@ -281,31 +301,34 @@ func (i CUDA) cdiDeviceRequestsFromAnnotations() []string {
281301
return nil
282302
}
283303

284-
var devices []string
285-
for key, value := range i.annotations {
304+
var annotationKeys []string
305+
for key := range i.annotations {
286306
for _, prefix := range i.annotationsPrefixes {
287307
if strings.HasPrefix(key, prefix) {
288-
devices = append(devices, strings.Split(value, ",")...)
308+
annotationKeys = append(annotationKeys, key)
289309
// There is no need to check additional prefixes since we
290310
// typically deduplicate devices in any case.
291311
break
292312
}
293313
}
294314
}
315+
// We sort the annotationKeys for consistent results.
316+
slices.Sort(annotationKeys)
317+
318+
var devices []string
319+
for _, key := range annotationKeys {
320+
devices = append(devices, strings.Split(i.annotations[key], ",")...)
321+
}
295322
return devices
296323
}
297324

298-
// VisibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
325+
// visibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
299326
// If any of the preferredVisibleDeviceEnvVars are present in the image, they
300327
// are used to determine the visible devices. If this is not the case, the
301328
// NVIDIA_VISIBLE_DEVICES environment variable is used.
302-
func (i CUDA) VisibleDevicesFromEnvVar() []string {
303-
for _, envVar := range i.preferredVisibleDeviceEnvVars {
304-
if i.HasEnvvar(envVar) {
305-
return i.DevicesFromEnvvars(i.preferredVisibleDeviceEnvVars...).List()
306-
}
307-
}
308-
return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List()
329+
func (i CUDA) visibleDevicesFromEnvVar() []string {
330+
envVars := i.visibleEnvVars()
331+
return i.devicesFromEnvvars(envVars...)
309332
}
310333

311334
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
@@ -391,7 +414,7 @@ func (m cdiDeviceMountRequest) qualifiedName() (string, error) {
391414

392415
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
393416
func (i CUDA) ImexChannelsFromEnvVar() []string {
394-
imexChannels := i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List()
417+
imexChannels := i.devicesFromEnvvars(EnvVarNvidiaImexChannels)
395418
if len(imexChannels) == 1 && imexChannels[0] == "all" {
396419
return nil
397420
}

internal/config/image/cuda_image_test.go

Lines changed: 88 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
429429
)
430430

431431
require.NoError(t, err)
432-
devices := image.VisibleDevicesFromEnvVar()
432+
devices := image.visibleDevicesFromEnvVar()
433433
require.EqualValues(t, tc.expectedDevices, devices)
434434
})
435435
}
@@ -508,13 +508,15 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
508508

509509
func TestVisibleDevices(t *testing.T) {
510510
var tests = []struct {
511-
description string
512-
mountDevices []specs.Mount
513-
envvarDevices string
514-
privileged bool
515-
acceptUnprivileged bool
516-
acceptMounts bool
517-
expectedDevices []string
511+
description string
512+
mountDevices []specs.Mount
513+
envvarDevices string
514+
privileged bool
515+
acceptUnprivileged bool
516+
acceptMounts bool
517+
preferredVisibleDeviceEnvVars []string
518+
env map[string]string
519+
expectedDevices []string
518520
}{
519521
{
520522
description: "Mount devices, unprivileged, no accept unprivileged",
@@ -597,20 +599,92 @@ func TestVisibleDevices(t *testing.T) {
597599
acceptMounts: false,
598600
expectedDevices: nil,
599601
},
602+
// New test cases for visibleEnvVars functionality
603+
{
604+
description: "preferred env var set and present in env, privileged",
605+
mountDevices: nil,
606+
envvarDevices: "",
607+
privileged: true,
608+
acceptUnprivileged: false,
609+
acceptMounts: true,
610+
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
611+
env: map[string]string{
612+
"DOCKER_RESOURCE_GPUS": "GPU-12345",
613+
},
614+
expectedDevices: []string{"GPU-12345"},
615+
},
616+
{
617+
description: "preferred env var set and present in env, unprivileged but accepted",
618+
mountDevices: nil,
619+
envvarDevices: "",
620+
privileged: false,
621+
acceptUnprivileged: true,
622+
acceptMounts: true,
623+
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
624+
env: map[string]string{
625+
"DOCKER_RESOURCE_GPUS": "GPU-12345",
626+
},
627+
expectedDevices: []string{"GPU-12345"},
628+
},
629+
{
630+
description: "preferred env var set and present in env, unprivileged and not accepted",
631+
mountDevices: nil,
632+
envvarDevices: "",
633+
privileged: false,
634+
acceptUnprivileged: false,
635+
acceptMounts: true,
636+
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
637+
env: map[string]string{
638+
"DOCKER_RESOURCE_GPUS": "GPU-12345",
639+
},
640+
expectedDevices: nil,
641+
},
642+
{
643+
description: "multiple preferred env vars, both present, privileged",
644+
mountDevices: nil,
645+
envvarDevices: "",
646+
privileged: true,
647+
acceptUnprivileged: false,
648+
acceptMounts: true,
649+
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
650+
env: map[string]string{
651+
"DOCKER_RESOURCE_GPUS": "GPU-12345",
652+
"DOCKER_RESOURCE_GPUS_ADDITIONAL": "GPU-67890",
653+
},
654+
expectedDevices: []string{"GPU-12345", "GPU-67890"},
655+
},
656+
{
657+
description: "preferred env var not present, fallback to NVIDIA_VISIBLE_DEVICES, privileged",
658+
mountDevices: nil,
659+
envvarDevices: "GPU-12345",
660+
privileged: true,
661+
acceptUnprivileged: false,
662+
acceptMounts: true,
663+
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
664+
env: map[string]string{
665+
EnvVarNvidiaVisibleDevices: "GPU-12345",
666+
},
667+
expectedDevices: []string{"GPU-12345"},
668+
},
600669
}
601670
for _, tc := range tests {
602671
t.Run(tc.description, func(t *testing.T) {
603-
// Wrap the call to getDevices() in a closure.
672+
// Create env map with both NVIDIA_VISIBLE_DEVICES and any additional env vars
673+
env := make(map[string]string)
674+
if tc.envvarDevices != "" {
675+
env[EnvVarNvidiaVisibleDevices] = tc.envvarDevices
676+
}
677+
for k, v := range tc.env {
678+
env[k] = v
679+
}
680+
604681
image, err := New(
605-
WithEnvMap(
606-
map[string]string{
607-
EnvVarNvidiaVisibleDevices: tc.envvarDevices,
608-
},
609-
),
682+
WithEnvMap(env),
610683
WithMounts(tc.mountDevices),
611684
WithPrivileged(tc.privileged),
612685
WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
613686
WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
687+
WithPreferredVisibleDevicesEnvVars(tc.preferredVisibleDeviceEnvVars...),
614688
)
615689
require.NoError(t, err)
616690
require.Equal(t, tc.expectedDevices, image.VisibleDevices())

internal/modifier/cdi_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func TestDeviceRequests(t *testing.T) {
9898
"another-prefix/bar": "example.com/device=baz",
9999
},
100100
},
101-
expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"},
101+
expectedDevices: []string{"example.com/device=baz", "example.com/device=bar"},
102102
},
103103
{
104104
description: "multiple matching annotations with duplicate devices",

internal/modifier/csv.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import (
3333
// NewCSVModifier creates a modifier that applies modifications to an OCI spec if required by the runtime wrapper.
3434
// The modifications are defined by CSV MountSpecs.
3535
func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image.CUDA) (oci.SpecModifier, error) {
36-
if devices := container.VisibleDevicesFromEnvVar(); len(devices) == 0 {
36+
if devices := container.VisibleDevices(); len(devices) == 0 {
3737
logger.Infof("No modification required; no devices requested")
3838
return nil, nil
3939
}

internal/modifier/gated.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import (
3737
//
3838
// If no devices are selected, no changes are made.
3939
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
40-
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
40+
if devices := image.VisibleDevices(); len(devices) == 0 {
4141
logger.Infof("No modification required; no devices requested")
4242
return nil, nil
4343
}

internal/modifier/graphics.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ import (
2929

3030
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
3131
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
32-
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
33-
if required, reason := requiresGraphicsModifier(containerImage); !required {
34-
logger.Infof("No graphics modifier required: %v", reason)
32+
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, container image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
33+
devices, reason := requiresGraphicsModifier(container)
34+
if len(devices) == 0 {
35+
logger.Infof("No graphics modifier required; %v", reason)
3536
return nil, nil
3637
}
3738

@@ -48,7 +49,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI
4849
devRoot := driver.Root
4950
drmNodes, err := discover.NewDRMNodesDiscoverer(
5051
logger,
51-
containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices),
52+
image.NewVisibleDevices(devices...),
5253
devRoot,
5354
hookCreator,
5455
)
@@ -64,14 +65,15 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerI
6465
}
6566

6667
// requiresGraphicsModifier determines whether a graphics modifier is required.
67-
func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) {
68-
if devices := cudaImage.VisibleDevicesFromEnvVar(); len(devices) == 0 {
69-
return false, "no devices requested"
68+
func requiresGraphicsModifier(cudaImage image.CUDA) ([]string, string) {
69+
devices := cudaImage.VisibleDevices()
70+
if len(devices) == 0 {
71+
return nil, "no devices requested"
7072
}
7173

7274
if !cudaImage.GetDriverCapabilities().Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) {
73-
return false, "no required capabilities requested"
75+
return nil, "no required capabilities requested"
7476
}
7577

76-
return true, ""
78+
return devices, ""
7779
}

0 commit comments

Comments
 (0)