Skip to content

Commit 980ca5d

Browse files
committed
Use functional options to construct runtime mode resolver
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 614e469 commit 980ca5d

File tree

8 files changed

+108
-20
lines changed

8 files changed

+108
-20
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,14 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
242242
}
243243
}
244244

245-
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
245+
func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
246+
hookConfig.Lock()
247+
defer hookConfig.Unlock()
248+
249+
if hookConfig.containerConfig != nil {
250+
return hookConfig.containerConfig
251+
}
252+
246253
var h HookState
247254
d := json.NewDecoder(os.Stdin)
248255
if err := d.Decode(&h); err != nil {
@@ -271,10 +278,13 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
271278
log.Panicln(err)
272279
}
273280

274-
return containerConfig{
281+
cc := containerConfig{
275282
Pid: h.Pid,
276283
Rootfs: s.Root.Path,
277284
Image: i,
278285
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
279286
}
287+
hookConfig.containerConfig = &cc
288+
289+
return hookConfig.containerConfig
280290
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ func TestGetNvidiaConfig(t *testing.T) {
487487
hookCfg := tc.hookConfig
488488
if hookCfg == nil {
489489
defaultConfig, _ := config.GetDefault()
490-
hookCfg = &hookConfig{defaultConfig}
490+
hookCfg = &hookConfig{Config: defaultConfig}
491491
}
492492
cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
493493
}

cmd/nvidia-container-runtime-hook/hook_config.go

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import (
77
"path"
88
"reflect"
99
"strings"
10+
"sync"
1011

1112
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
1213
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
14+
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
1315
)
1416

1517
const (
@@ -20,7 +22,9 @@ const (
2022
// hookConfig wraps the toolkit config.
2123
// This allows for functions to be defined on the local type.
2224
type hookConfig struct {
25+
sync.Mutex
2326
*config.Config
27+
containerConfig *containerConfig
2428
}
2529

2630
// loadConfig loads the required paths for the hook config.
@@ -55,7 +59,7 @@ func getHookConfig() (*hookConfig, error) {
5559
if err != nil {
5660
return nil, fmt.Errorf("failed to load config: %v", err)
5761
}
58-
config := &hookConfig{cfg}
62+
config := &hookConfig{Config: cfg}
5963

6064
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
6165
if config.SupportedDriverCapabilities == "all" {
@@ -73,8 +77,8 @@ func getHookConfig() (*hookConfig, error) {
7377

7478
// getConfigOption returns the toml config option associated with the
7579
// specified struct field.
76-
func (c hookConfig) getConfigOption(fieldName string) string {
77-
t := reflect.TypeOf(c)
80+
func (c *hookConfig) getConfigOption(fieldName string) string {
81+
t := reflect.TypeOf(&c)
7882
f, ok := t.FieldByName(fieldName)
7983
if !ok {
8084
return fieldName
@@ -127,3 +131,20 @@ func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string {
127131
}
128132
return []string{flag}
129133
}
134+
135+
func (c *hookConfig) assertModeIsLegacy() error {
136+
if c.NVIDIAContainerRuntimeHookConfig.SkipModeDetection {
137+
return nil
138+
}
139+
140+
mr := info.NewRuntimeModeResolver(
141+
info.WithLogger(&logInterceptor{}),
142+
info.WithImage(&c.containerConfig.Image),
143+
)
144+
145+
mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode)
146+
if mode == "legacy" {
147+
return nil
148+
}
149+
return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
150+
}

cmd/nvidia-container-runtime-hook/hook_config_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
9090
}
9191
}
9292

93-
var cfg hookConfig
93+
var cfg *hookConfig
9494
getHookConfig := func() {
9595
c, _ := getHookConfig()
96-
cfg = *c
96+
cfg = c
9797
}
9898

9999
if tc.expectedPanic {

cmd/nvidia-container-runtime-hook/main.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func getCLIPath(config config.ContainerCLIConfig) string {
5555
}
5656

5757
// getRootfsPath returns an absolute path. We don't need to resolve symlinks for now.
58-
func getRootfsPath(config containerConfig) string {
58+
func getRootfsPath(config *containerConfig) string {
5959
rootfs, err := filepath.Abs(config.Rootfs)
6060
if err != nil {
6161
log.Panicln(err)
@@ -82,8 +82,8 @@ func doPrestart() {
8282
return
8383
}
8484

85-
if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
86-
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.")
85+
if err := hook.assertModeIsLegacy(); err != nil {
86+
log.Panicf("%v", err)
8787
}
8888

8989
rootfs := getRootfsPath(container)

internal/info/auto.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,75 @@ import (
2323
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2424
)
2525

26+
type RuntimeModeResolver interface {
27+
ResolveRuntimeMode(string) string
28+
}
29+
30+
type modeResolver struct {
31+
logger logger.Interface
32+
// TODO: This only needs to consider the requested devices.
33+
image *image.CUDA
34+
propertyExtractor info.PropertyExtractor
35+
}
36+
37+
type Option func(*modeResolver)
38+
39+
func WithLogger(logger logger.Interface) Option {
40+
return func(mr *modeResolver) {
41+
mr.logger = logger
42+
}
43+
}
44+
45+
func WithImage(image *image.CUDA) Option {
46+
return func(mr *modeResolver) {
47+
mr.image = image
48+
}
49+
}
50+
51+
func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option {
52+
return func(mr *modeResolver) {
53+
mr.propertyExtractor = propertyExtractor
54+
}
55+
}
56+
57+
func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver {
58+
r := &modeResolver{}
59+
for _, opt := range opts {
60+
opt(r)
61+
}
62+
if r.logger == nil {
63+
r.logger = &logger.NullLogger{}
64+
}
65+
66+
return r
67+
}
68+
2669
// ResolveAutoMode determines the correct mode for the platform if set to "auto"
2770
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
28-
return resolveMode(logger, mode, image, nil)
71+
r := modeResolver{
72+
logger: logger,
73+
image: &image,
74+
propertyExtractor: nil,
75+
}
76+
return r.ResolveRuntimeMode(mode)
2977
}
3078

31-
func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) {
79+
func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode string) {
3280
if mode != "auto" {
33-
logger.Infof("Using requested mode '%s'", mode)
81+
m.logger.Infof("Using requested mode '%s'", mode)
3482
return mode
3583
}
3684
defer func() {
37-
logger.Infof("Auto-detected mode as '%v'", rmode)
85+
m.logger.Infof("Auto-detected mode as '%v'", rmode)
3886
}()
3987

40-
if image.OnlyFullyQualifiedCDIDevices() {
88+
if m.image.OnlyFullyQualifiedCDIDevices() {
4189
return "cdi"
4290
}
4391

4492
nvinfo := info.New(
45-
info.WithLogger(logger),
46-
info.WithPropertyExtractor(propertyExtractor),
93+
info.WithLogger(m.logger),
94+
info.WithPropertyExtractor(m.propertyExtractor),
4795
)
4896

4997
switch nvinfo.ResolvePlatform() {

internal/info/auto_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,12 @@ func TestResolveAutoMode(t *testing.T) {
251251
image.WithAcceptDeviceListAsVolumeMounts(true),
252252
image.WithAcceptEnvvarUnprivileged(true),
253253
)
254-
mode := resolveMode(logger, tc.mode, image, properties)
254+
mr := NewRuntimeModeResolver(
255+
WithLogger(logger),
256+
WithImage(&image),
257+
WithPropertyExtractor(properties),
258+
)
259+
mode := mr.ResolveRuntimeMode(tc.mode)
255260
require.EqualValues(t, tc.expectedMode, mode)
256261
})
257262
}

internal/runtime/runtime_factory.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe
136136
return "", nil, err
137137
}
138138

139-
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
139+
modeResolver := info.NewRuntimeModeResolver(
140+
info.WithLogger(logger),
141+
info.WithImage(&image),
142+
)
143+
mode := modeResolver.ResolveRuntimeMode(cfg.NVIDIAContainerRuntimeConfig.Mode)
140144
// We update the mode here so that we can continue passing just the config to other functions.
141145
cfg.NVIDIAContainerRuntimeConfig.Mode = mode
142146

0 commit comments

Comments
 (0)