Skip to content

Commit c076436

Browse files
committed
[no-relnote] Refactor config handling for hook
This change removes indirect calls to get the default config from the nvidia-container-runtime-hook. Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 1afada7 commit c076436

File tree

5 files changed

+74
-54
lines changed

5 files changed

+74
-54
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []stri
157157
return containerImage.VisibleDevicesFromEnvVar()
158158
}
159159

160-
func getDevices(hookConfig *HookConfig, image image.CUDA, privileged bool) []string {
160+
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
161161
// If enabled, try and get the device list from volume mounts first
162162
if hookConfig.AcceptDeviceListAsVolumeMounts {
163163
devices := image.VisibleDevicesFromMounts()
@@ -197,7 +197,7 @@ func getMigDevices(image image.CUDA, envvar string) *string {
197197
return &devices
198198
}
199199

200-
func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool) []string {
200+
func (hookConfig *hookConfig) getImexChannels(image image.CUDA, privileged bool) []string {
201201
// If enabled, try and get the device list from volume mounts first
202202
if hookConfig.AcceptDeviceListAsVolumeMounts {
203203
devices := image.ImexChannelsFromMounts()
@@ -217,10 +217,10 @@ func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool)
217217
return nil
218218
}
219219

220-
func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
220+
func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
221221
// We use the default driver capabilities by default. This is filtered to only include the
222222
// supported capabilities
223-
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
223+
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
224224

225225
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
226226

@@ -244,10 +244,10 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo
244244
return capabilities
245245
}
246246

247-
func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) *nvidiaConfig {
247+
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
248248
legacyImage := image.IsLegacy()
249249

250-
devices := getDevices(hookConfig, image, privileged)
250+
devices := hookConfig.getDevices(image, privileged)
251251
if len(devices) == 0 {
252252
// empty devices means this is not a GPU container.
253253
return nil
@@ -269,7 +269,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool)
269269
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
270270
}
271271

272-
imexChannels := getImexChannels(hookConfig, image, privileged)
272+
imexChannels := hookConfig.getImexChannels(image, privileged)
273273

274274
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
275275

@@ -288,7 +288,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool)
288288
}
289289
}
290290

291-
func getContainerConfig(hook HookConfig) (config containerConfig) {
291+
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
292292
var h HookState
293293
d := json.NewDecoder(os.Stdin)
294294
if err := d.Decode(&h); err != nil {
@@ -305,7 +305,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
305305
image, err := image.New(
306306
image.WithEnv(s.Process.Env),
307307
image.WithMounts(s.Mounts),
308-
image.WithDisableRequire(hook.DisableRequire),
308+
image.WithDisableRequire(hookConfig.DisableRequire),
309309
)
310310
if err != nil {
311311
log.Panicln(err)
@@ -316,6 +316,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
316316
Pid: h.Pid,
317317
Rootfs: s.Root.Path,
318318
Image: image,
319-
Nvidia: getNvidiaConfig(&hook, image, privileged),
319+
Nvidia: hookConfig.getNvidiaConfig(image, privileged),
320320
}
321321
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/opencontainers/runtime-spec/specs-go"
88
"github.com/stretchr/testify/require"
99

10+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
1011
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
1112
)
1213

@@ -15,7 +16,7 @@ func TestGetNvidiaConfig(t *testing.T) {
1516
description string
1617
env map[string]string
1718
privileged bool
18-
hookConfig *HookConfig
19+
hookConfig *hookConfig
1920
expectedConfig *nvidiaConfig
2021
expectedPanic bool
2122
}{
@@ -394,8 +395,10 @@ func TestGetNvidiaConfig(t *testing.T) {
394395
image.EnvVarNvidiaDriverCapabilities: "all",
395396
},
396397
privileged: true,
397-
hookConfig: &HookConfig{
398-
SupportedDriverCapabilities: "video,display",
398+
hookConfig: &hookConfig{
399+
Config: &config.Config{
400+
SupportedDriverCapabilities: "video,display",
401+
},
399402
},
400403
expectedConfig: &nvidiaConfig{
401404
Devices: []string{"all"},
@@ -409,8 +412,10 @@ func TestGetNvidiaConfig(t *testing.T) {
409412
image.EnvVarNvidiaDriverCapabilities: "video,display",
410413
},
411414
privileged: true,
412-
hookConfig: &HookConfig{
413-
SupportedDriverCapabilities: "video,display,compute,utility",
415+
hookConfig: &hookConfig{
416+
Config: &config.Config{
417+
SupportedDriverCapabilities: "video,display,compute,utility",
418+
},
414419
},
415420
expectedConfig: &nvidiaConfig{
416421
Devices: []string{"all"},
@@ -423,8 +428,10 @@ func TestGetNvidiaConfig(t *testing.T) {
423428
image.EnvVarNvidiaVisibleDevices: "all",
424429
},
425430
privileged: true,
426-
hookConfig: &HookConfig{
427-
SupportedDriverCapabilities: "video,display,utility,compute",
431+
hookConfig: &hookConfig{
432+
Config: &config.Config{
433+
SupportedDriverCapabilities: "video,display,utility,compute",
434+
},
428435
},
429436
expectedConfig: &nvidiaConfig{
430437
Devices: []string{"all"},
@@ -438,9 +445,11 @@ func TestGetNvidiaConfig(t *testing.T) {
438445
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
439446
},
440447
privileged: true,
441-
hookConfig: &HookConfig{
442-
SwarmResource: "DOCKER_SWARM_RESOURCE",
443-
SupportedDriverCapabilities: "video,display,utility,compute",
448+
hookConfig: &hookConfig{
449+
Config: &config.Config{
450+
SwarmResource: "DOCKER_SWARM_RESOURCE",
451+
SupportedDriverCapabilities: "video,display,utility,compute",
452+
},
444453
},
445454
expectedConfig: &nvidiaConfig{
446455
Devices: []string{"GPU1", "GPU2"},
@@ -454,9 +463,11 @@ func TestGetNvidiaConfig(t *testing.T) {
454463
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
455464
},
456465
privileged: true,
457-
hookConfig: &HookConfig{
458-
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
459-
SupportedDriverCapabilities: "video,display,utility,compute",
466+
hookConfig: &hookConfig{
467+
Config: &config.Config{
468+
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
469+
SupportedDriverCapabilities: "video,display,utility,compute",
470+
},
460471
},
461472
expectedConfig: &nvidiaConfig{
462473
Devices: []string{"GPU1", "GPU2"},
@@ -470,14 +481,14 @@ func TestGetNvidiaConfig(t *testing.T) {
470481
image.WithEnvMap(tc.env),
471482
)
472483
// Wrap the call to getNvidiaConfig() in a closure.
473-
var config *nvidiaConfig
484+
var cfg *nvidiaConfig
474485
getConfig := func() {
475-
hookConfig := tc.hookConfig
476-
if hookConfig == nil {
477-
defaultConfig, _ := getDefaultHookConfig()
478-
hookConfig = &defaultConfig
486+
hookCfg := tc.hookConfig
487+
if hookCfg == nil {
488+
defaultConfig, _ := config.GetDefault()
489+
hookCfg = &hookConfig{defaultConfig}
479490
}
480-
config = getNvidiaConfig(hookConfig, image, tc.privileged)
491+
cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
481492
}
482493

483494
// For any tests that are expected to panic, make sure they do.
@@ -491,18 +502,18 @@ func TestGetNvidiaConfig(t *testing.T) {
491502

492503
// And start comparing the test results to the expected results.
493504
if tc.expectedConfig == nil {
494-
require.Nil(t, config, tc.description)
505+
require.Nil(t, cfg, tc.description)
495506
return
496507
}
497508

498-
require.NotNil(t, config, tc.description)
509+
require.NotNil(t, cfg, tc.description)
499510

500-
require.Equal(t, tc.expectedConfig.Devices, config.Devices)
501-
require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices)
502-
require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices)
503-
require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities)
511+
require.Equal(t, tc.expectedConfig.Devices, cfg.Devices)
512+
require.Equal(t, tc.expectedConfig.MigConfigDevices, cfg.MigConfigDevices)
513+
require.Equal(t, tc.expectedConfig.MigMonitorDevices, cfg.MigMonitorDevices)
514+
require.Equal(t, tc.expectedConfig.DriverCapabilities, cfg.DriverCapabilities)
504515

505-
require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements)
516+
require.ElementsMatch(t, tc.expectedConfig.Requirements, cfg.Requirements)
506517
})
507518
}
508519
}
@@ -612,10 +623,11 @@ func TestDeviceListSourcePriority(t *testing.T) {
612623
),
613624
image.WithMounts(tc.mountDevices),
614625
)
615-
hookConfig, _ := getDefaultHookConfig()
616-
hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
617-
hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
618-
devices = getDevices(&hookConfig, image, tc.privileged)
626+
defaultConfig, _ := config.GetDefault()
627+
cfg := &hookConfig{defaultConfig}
628+
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
629+
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
630+
devices = cfg.getDevices(image, tc.privileged)
619631
}
620632

621633
// For all other tests, just grab the devices and check the results
@@ -940,8 +952,10 @@ func TestGetDriverCapabilities(t *testing.T) {
940952
t.Run(tc.description, func(t *testing.T) {
941953
var capabilities string
942954

943-
c := HookConfig{
944-
SupportedDriverCapabilities: tc.supportedCapabilities,
955+
c := hookConfig{
956+
Config: &config.Config{
957+
SupportedDriverCapabilities: tc.supportedCapabilities,
958+
},
945959
}
946960

947961
image, _ := image.New(

cmd/nvidia-container-runtime-hook/hook_config.go

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ const (
1717
driverPath = "/run/nvidia/driver"
1818
)
1919

20-
// HookConfig : options for the nvidia-container-runtime-hook.
21-
type HookConfig config.Config
20+
// hookConfig wraps the toolkit config.
21+
// This allows for functions to be defined on the local type.
22+
type hookConfig struct {
23+
*config.Config
24+
}
2225

2326
// loadConfig loads the required paths for the hook config.
2427
func loadConfig() (*config.Config, error) {
@@ -47,12 +50,12 @@ func loadConfig() (*config.Config, error) {
4750
return config.GetDefault()
4851
}
4952

50-
func getHookConfig() (*HookConfig, error) {
53+
func getHookConfig() (*hookConfig, error) {
5154
cfg, err := loadConfig()
5255
if err != nil {
5356
return nil, fmt.Errorf("failed to load config: %v", err)
5457
}
55-
config := (*HookConfig)(cfg)
58+
config := &hookConfig{cfg}
5659

5760
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
5861
if config.SupportedDriverCapabilities == "all" {
@@ -70,7 +73,7 @@ func getHookConfig() (*HookConfig, error) {
7073

7174
// getConfigOption returns the toml config option associated with the
7275
// specified struct field.
73-
func (c HookConfig) getConfigOption(fieldName string) string {
76+
func (c hookConfig) getConfigOption(fieldName string) string {
7477
t := reflect.TypeOf(c)
7578
f, ok := t.FieldByName(fieldName)
7679
if !ok {
@@ -84,7 +87,7 @@ func (c HookConfig) getConfigOption(fieldName string) string {
8487
}
8588

8689
// getSwarmResourceEnvvars returns the swarm resource envvars for the config.
87-
func (c *HookConfig) getSwarmResourceEnvvars() []string {
90+
func (c *hookConfig) getSwarmResourceEnvvars() []string {
8891
if c.SwarmResource == "" {
8992
return nil
9093
}

cmd/nvidia-container-runtime-hook/hook_config_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/stretchr/testify/require"
2525

26+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
2627
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
2728
)
2829

@@ -89,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
8990
}
9091
}
9192

92-
var config HookConfig
93+
var cfg hookConfig
9394
getHookConfig := func() {
9495
c, _ := getHookConfig()
95-
config = *c
96+
cfg = *c
9697
}
9798

9899
if tc.expectedPanic {
@@ -102,7 +103,7 @@ func TestGetHookConfig(t *testing.T) {
102103

103104
getHookConfig()
104105

105-
require.EqualValues(t, tc.expectedDriverCapabilities, config.SupportedDriverCapabilities)
106+
require.EqualValues(t, tc.expectedDriverCapabilities, cfg.SupportedDriverCapabilities)
106107
})
107108
}
108109
}
@@ -144,8 +145,10 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {
144145

145146
for i, tc := range testCases {
146147
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
147-
c := &HookConfig{
148-
SwarmResource: tc.value,
148+
c := &hookConfig{
149+
Config: &config.Config{
150+
SwarmResource: tc.value,
151+
},
149152
}
150153

151154
envvars := c.getSwarmResourceEnvvars()

cmd/nvidia-container-runtime-hook/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func doPrestart() {
7575
}
7676
cli := hook.NVIDIAContainerCLIConfig
7777

78-
container := getContainerConfig(*hook)
78+
container := hook.getContainerConfig()
7979
nvidia := container.Nvidia
8080
if nvidia == nil {
8181
// Not a GPU container, nothing to do.

0 commit comments

Comments
 (0)