Skip to content

Commit 99cc0ae

Browse files
author
Evan Lezar
committed
Merge branch 'pass-image-to-csv-constructor' into 'main'
Pass image when constructing CSV modifier See merge request nvidia/container-toolkit/container-toolkit!451
2 parents f08e48e + cca343a commit 99cc0ae

File tree

3 files changed

+20
-53
lines changed

3 files changed

+20
-53
lines changed

internal/modifier/csv.go

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,7 @@ const (
4545

4646
// NewCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
4747
// The modifications are defined by CSV MountSpecs.
48-
func NewCSVModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
49-
rawSpec, err := ociSpec.Load()
50-
if err != nil {
51-
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
52-
}
53-
54-
image, err := image.NewCUDAImageFromSpec(rawSpec)
55-
if err != nil {
56-
return nil, err
57-
}
58-
48+
func NewCSVModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) {
5949
if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 {
6050
logger.Infof("No modification required; no devices requested")
6151
return nil, nil

internal/modifier/csv_test.go

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
package modifier
1818

1919
import (
20-
"fmt"
2120
"testing"
2221

2322
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
24-
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
23+
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
2524
"github.com/opencontainers/runtime-spec/specs-go"
2625
testlog "github.com/sirupsen/logrus/hooks/test"
2726
"github.com/stretchr/testify/require"
@@ -31,54 +30,32 @@ func TestNewCSVModifier(t *testing.T) {
3130
logger, _ := testlog.NewNullLogger()
3231

3332
testCases := []struct {
34-
description string
35-
cfg *config.Config
36-
spec oci.Spec
37-
visibleDevices string
38-
expectedError error
39-
expectedNil bool
33+
description string
34+
cfg *config.Config
35+
image image.CUDA
36+
expectedError error
37+
expectedNil bool
4038
}{
4139
{
42-
description: "spec load error returns error",
43-
spec: &oci.SpecMock{
44-
LoadFunc: func() (*specs.Spec, error) {
45-
return nil, fmt.Errorf("load failed")
46-
},
47-
},
48-
expectedError: fmt.Errorf("load failed"),
40+
description: "visible devices not set returns nil",
41+
image: image.CUDA{},
42+
expectedNil: true,
4943
},
5044
{
51-
description: "visible devices not set returns nil",
52-
visibleDevices: "NOT_SET",
53-
expectedNil: true,
45+
description: "visible devices empty returns nil",
46+
image: image.CUDA{"NVIDIA_VISIBLE_DEVICES": ""},
47+
expectedNil: true,
5448
},
5549
{
56-
description: "visible devices empty returns nil",
57-
visibleDevices: "",
58-
expectedNil: true,
59-
},
60-
{
61-
description: "visible devices 'void' returns nil",
62-
visibleDevices: "void",
63-
expectedNil: true,
50+
description: "visible devices 'void' returns nil",
51+
image: image.CUDA{"NVIDIA_VISIBLE_DEVICES": "void"},
52+
expectedNil: true,
6453
},
6554
}
6655

6756
for _, tc := range testCases {
6857
t.Run(tc.description, func(t *testing.T) {
69-
spec := tc.spec
70-
if spec == nil {
71-
spec = &oci.SpecMock{
72-
LookupEnvFunc: func(s string) (string, bool) {
73-
if tc.visibleDevices != "NOT_SET" && s == visibleDevicesEnvvar {
74-
return tc.visibleDevices, true
75-
}
76-
return "", false
77-
},
78-
}
79-
}
80-
81-
m, err := NewCSVModifier(logger, tc.cfg, spec)
58+
m, err := NewCSVModifier(logger, tc.cfg, tc.image)
8259
if tc.expectedError != nil {
8360
require.Error(t, err)
8461
} else {

internal/runtime/runtime_factory.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
7373
}
7474

7575
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
76-
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec)
76+
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image)
7777
if err != nil {
7878
return nil, err
7979
}
@@ -106,12 +106,12 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
106106
return modifiers, nil
107107
}
108108

109-
func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
109+
func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) {
110110
switch mode {
111111
case "legacy":
112112
return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
113113
case "csv":
114-
return modifier.NewCSVModifier(logger, cfg, ociSpec)
114+
return modifier.NewCSVModifier(logger, cfg, image)
115115
case "cdi":
116116
return modifier.NewCDIModifier(logger, cfg, ociSpec)
117117
}

0 commit comments

Comments
 (0)