Skip to content

Commit 0dddd5c

Browse files
authored
Merge pull request #910 from elezar/default-to-cdi
Use just-in-time CDI spec generation by default in the NVIDIA Container Runtime
2 parents 76b71a5 + 17c5d1d commit 0dddd5c

File tree

11 files changed

+221
-109
lines changed

11 files changed

+221
-109
lines changed

cmd/nvidia-container-runtime-hook/container_config.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,14 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
242242
}
243243
}
244244

245-
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
245+
func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
246+
hookConfig.Lock()
247+
defer hookConfig.Unlock()
248+
249+
if hookConfig.containerConfig != nil {
250+
return hookConfig.containerConfig
251+
}
252+
246253
var h HookState
247254
d := json.NewDecoder(os.Stdin)
248255
if err := d.Decode(&h); err != nil {
@@ -271,10 +278,13 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
271278
log.Panicln(err)
272279
}
273280

274-
return containerConfig{
281+
cc := containerConfig{
275282
Pid: h.Pid,
276283
Rootfs: s.Root.Path,
277284
Image: i,
278285
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
279286
}
287+
hookConfig.containerConfig = &cc
288+
289+
return hookConfig.containerConfig
280290
}

cmd/nvidia-container-runtime-hook/container_config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ func TestGetNvidiaConfig(t *testing.T) {
487487
hookCfg := tc.hookConfig
488488
if hookCfg == nil {
489489
defaultConfig, _ := config.GetDefault()
490-
hookCfg = &hookConfig{defaultConfig}
490+
hookCfg = &hookConfig{Config: defaultConfig}
491491
}
492492
cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
493493
}

cmd/nvidia-container-runtime-hook/hook_config.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ import (
77
"path"
88
"reflect"
99
"strings"
10+
"sync"
1011

1112
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
1213
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
14+
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
1315
)
1416

1517
const (
@@ -20,7 +22,9 @@ const (
2022
// hookConfig wraps the toolkit config.
2123
// This allows for functions to be defined on the local type.
2224
type hookConfig struct {
25+
// Mutex is locked by getContainerConfig for the duration of the call, which
// covers its reads and writes of containerConfig.
sync.Mutex
2326
// Config is the embedded toolkit config.
*config.Config
27+
// containerConfig is nil until getContainerConfig stores the config it has
// built; subsequent getContainerConfig calls return the stored pointer.
containerConfig *containerConfig
2428
}
2529

2630
// loadConfig loads the required paths for the hook config.
@@ -55,7 +59,7 @@ func getHookConfig() (*hookConfig, error) {
5559
if err != nil {
5660
return nil, fmt.Errorf("failed to load config: %v", err)
5761
}
58-
config := &hookConfig{cfg}
62+
config := &hookConfig{Config: cfg}
5963

6064
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
6165
if config.SupportedDriverCapabilities == "all" {
@@ -73,8 +77,8 @@ func getHookConfig() (*hookConfig, error) {
7377

7478
// getConfigOption returns the toml config option associated with the
7579
// specified struct field.
76-
func (c hookConfig) getConfigOption(fieldName string) string {
77-
t := reflect.TypeOf(c)
80+
func (c *hookConfig) getConfigOption(fieldName string) string {
81+
t := reflect.TypeOf(&c)
7882
f, ok := t.FieldByName(fieldName)
7983
if !ok {
8084
return fieldName
@@ -127,3 +131,21 @@ func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string {
127131
}
128132
return []string{flag}
129133
}
134+
135+
func (c *hookConfig) assertModeIsLegacy() error {
136+
if c.NVIDIAContainerRuntimeHookConfig.SkipModeDetection {
137+
return nil
138+
}
139+
140+
mr := info.NewRuntimeModeResolver(
141+
info.WithLogger(&logInterceptor{}),
142+
info.WithImage(&c.containerConfig.Image),
143+
info.WithDefaultMode(info.LegacyRuntimeMode),
144+
)
145+
146+
mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode)
147+
if mode == "legacy" {
148+
return nil
149+
}
150+
return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
151+
}

cmd/nvidia-container-runtime-hook/hook_config_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
9090
}
9191
}
9292

93-
var cfg hookConfig
93+
var cfg *hookConfig
9494
getHookConfig := func() {
9595
c, _ := getHookConfig()
96-
cfg = *c
96+
cfg = c
9797
}
9898

9999
if tc.expectedPanic {

cmd/nvidia-container-runtime-hook/main.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func getCLIPath(config config.ContainerCLIConfig) string {
5555
}
5656

5757
// getRootfsPath returns an absolute path. We don't need to resolve symlinks for now.
58-
func getRootfsPath(config containerConfig) string {
58+
func getRootfsPath(config *containerConfig) string {
5959
rootfs, err := filepath.Abs(config.Rootfs)
6060
if err != nil {
6161
log.Panicln(err)
@@ -82,8 +82,8 @@ func doPrestart() {
8282
return
8383
}
8484

85-
if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
86-
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.")
85+
if err := hook.assertModeIsLegacy(); err != nil {
86+
log.Panicf("%v", err)
8787
}
8888

8989
rootfs := getRootfsPath(container)

cmd/nvidia-container-runtime/main_test.go

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,10 @@ func TestGoodInput(t *testing.T) {
122122
err = cmdCreate.Run()
123123
require.NoError(t, err, "runtime should not return an error")
124124

125-
// Check config.json for NVIDIA prestart hook
125+
// Check config.json to ensure that the NVIDIA prestart hook was not inserted.
126126
spec, err = cfg.getRuntimeSpec()
127127
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
128-
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
129-
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
128+
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
130129
}
131130

132131
// NVIDIA prestart hook already present in config file
@@ -168,11 +167,10 @@ func TestDuplicateHook(t *testing.T) {
168167
output, err := cmdCreate.CombinedOutput()
169168
require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output))
170169

171-
// Check config.json for NVIDIA prestart hook
170+
// Check config.json to ensure that the NVIDIA prestart hook was removed.
172171
spec, err = cfg.getRuntimeSpec()
173172
require.NoError(t, err, "should be no errors when reading and parsing spec from config.json")
174-
require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json")
175-
require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json")
173+
require.Empty(t, spec.Hooks, "there should be no hooks in config.json")
176174
}
177175

178176
// addNVIDIAHook is a basic wrapper for an addHookModifier that is used for
@@ -240,18 +238,3 @@ func (c testConfig) generateNewRuntimeSpec() error {
240238
}
241239
return nil
242240
}
243-
244-
// Return number of valid NVIDIA prestart hooks in runtime spec
245-
func nvidiaHookCount(hooks *specs.Hooks) int {
246-
if hooks == nil {
247-
return 0
248-
}
249-
250-
count := 0
251-
for _, hook := range hooks.Prestart {
252-
if strings.Contains(hook.Path, nvidiaHook) {
253-
count++
254-
}
255-
}
256-
return count
257-
}

internal/info/auto.go

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,34 +23,114 @@ import (
2323
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2424
)
2525

26+
// RuntimeMode names a mode of operation of the NVIDIA Container Runtime.
type RuntimeMode string

// The runtime modes known to this package.
const (
	// LegacyRuntimeMode is the mode in which the nvidia-container-runtime
	// injects the nvidia-container-runtime-hook as a prestart hook into the
	// incoming container config. The hook then invokes the
	// nvidia-container-cli to perform the required modifications to the
	// container.
	LegacyRuntimeMode RuntimeMode = "legacy"
	// CSVRuntimeMode is the mode in which the nvidia-container-runtime
	// processes a set of CSV files to determine which container modifications
	// are required. The contents of these files are used to generate an
	// in-memory CDI specification, which is then used to modify the container
	// config.
	CSVRuntimeMode RuntimeMode = "csv"
	// CDIRuntimeMode is the mode in which the nvidia-container-runtime applies
	// the container config modifications required for the requested CDI
	// devices, in the same way that other CDI clients would.
	CDIRuntimeMode RuntimeMode = "cdi"
	// JitCDIRuntimeMode is the mode in which the nvidia-container-runtime
	// generates in-memory CDI specifications for the requested NVIDIA devices.
	JitCDIRuntimeMode RuntimeMode = "jit-cdi"
)
48+
49+
// RuntimeModeResolver maps a mode string to the RuntimeMode that should be used.
type RuntimeModeResolver interface {
50+
// ResolveRuntimeMode returns the RuntimeMode selected for the given mode string.
ResolveRuntimeMode(string) RuntimeMode
51+
}
52+
53+
// modeResolver is the RuntimeModeResolver implementation returned by
// NewRuntimeModeResolver.
type modeResolver struct {
54+
// logger receives the informational messages emitted during resolution.
logger logger.Interface
55+
// TODO: This only needs to consider the requested devices.
56+
// image is queried for whether it requests only fully-qualified CDI devices.
image *image.CUDA
57+
// propertyExtractor is passed to info.New when the platform is resolved.
propertyExtractor info.PropertyExtractor
58+
// defaultMode is the result of "auto" resolution for the NVML and WSL
// platforms and when no more specific case applies.
defaultMode RuntimeMode
59+
}
60+
61+
// Option is a functional option applied to the resolver constructed by
// NewRuntimeModeResolver.
type Option func(*modeResolver)
62+
63+
func WithDefaultMode(defaultMode RuntimeMode) Option {
64+
return func(mr *modeResolver) {
65+
mr.defaultMode = defaultMode
66+
}
67+
}
68+
69+
func WithLogger(logger logger.Interface) Option {
70+
return func(mr *modeResolver) {
71+
mr.logger = logger
72+
}
73+
}
74+
75+
func WithImage(image *image.CUDA) Option {
76+
return func(mr *modeResolver) {
77+
mr.image = image
78+
}
79+
}
80+
81+
func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option {
82+
return func(mr *modeResolver) {
83+
mr.propertyExtractor = propertyExtractor
84+
}
85+
}
86+
87+
func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver {
88+
r := &modeResolver{
89+
defaultMode: JitCDIRuntimeMode,
90+
}
91+
for _, opt := range opts {
92+
opt(r)
93+
}
94+
if r.logger == nil {
95+
r.logger = &logger.NullLogger{}
96+
}
97+
98+
return r
99+
}
100+
26101
// ResolveAutoMode determines the correct mode for the platform if set to "auto"
27-
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
28-
return resolveMode(logger, mode, image, nil)
102+
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode RuntimeMode) {
103+
r := modeResolver{
104+
logger: logger,
105+
image: &image,
106+
propertyExtractor: nil,
107+
}
108+
return r.ResolveRuntimeMode(mode)
29109
}
30110

31-
func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) {
111+
func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode RuntimeMode) {
32112
if mode != "auto" {
33-
logger.Infof("Using requested mode '%s'", mode)
34-
return mode
113+
m.logger.Infof("Using requested mode '%s'", mode)
114+
return RuntimeMode(mode)
35115
}
36116
defer func() {
37-
logger.Infof("Auto-detected mode as '%v'", rmode)
117+
m.logger.Infof("Auto-detected mode as '%v'", rmode)
38118
}()
39119

40-
if image.OnlyFullyQualifiedCDIDevices() {
41-
return "cdi"
120+
if m.image.OnlyFullyQualifiedCDIDevices() {
121+
return CDIRuntimeMode
42122
}
43123

44124
nvinfo := info.New(
45-
info.WithLogger(logger),
46-
info.WithPropertyExtractor(propertyExtractor),
125+
info.WithLogger(m.logger),
126+
info.WithPropertyExtractor(m.propertyExtractor),
47127
)
48128

49129
switch nvinfo.ResolvePlatform() {
50130
case info.PlatformNVML, info.PlatformWSL:
51-
return "legacy"
131+
return m.defaultMode
52132
case info.PlatformTegra:
53-
return "csv"
133+
return CSVRuntimeMode
54134
}
55-
return "legacy"
135+
return m.defaultMode
56136
}

0 commit comments

Comments
 (0)