Skip to content

Commit a1d48b6

Browse files
committed
Simplify nvcdi interface
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent fd7d222 commit a1d48b6

File tree

9 files changed

+87
-257
lines changed

9 files changed

+87
-257
lines changed

pkg/nvcdi/api.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package nvcdi
1818

1919
import (
20-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2120
"tags.cncf.io/container-device-interface/pkg/cdi"
2221
"tags.cncf.io/container-device-interface/specs-go"
2322

@@ -29,12 +28,9 @@ import (
2928
type Interface interface {
3029
SpecGenerator
3130
GetCommonEdits() (*cdi.ContainerEdits, error)
32-
GetAllDeviceSpecs() ([]specs.Device, error)
33-
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)
34-
GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error)
35-
GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error)
36-
GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error)
3731
GetDeviceSpecsByID(...string) ([]specs.Device, error)
32+
// Deprecated: GetAllDeviceSpecs is deprecated. Use GetDeviceSpecsByID("all") instead.
33+
GetAllDeviceSpecs() ([]specs.Device, error)
3834
}
3935

4036
// A SpecGenerator is used to generate a complete CDI spec for a collected set

pkg/nvcdi/gds.go

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,19 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2726
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
28-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
2927
)
3028

3129
type gdslib nvcdilib
3230

33-
var _ Interface = (*gdslib)(nil)
31+
var _ wrapped = (*gdslib)(nil)
3432

35-
// GetAllDeviceSpecs returns the device specs for all available devices.
36-
func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) {
33+
// GetDeviceSpecsByID returns the device specs for the specified devices.
34+
func (l *gdslib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
3735
discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot)
3836
if err != nil {
3937
return nil, fmt.Errorf("failed to create GPUDirect Storage discoverer: %v", err)
@@ -55,36 +53,3 @@ func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) {
5553
func (l *gdslib) GetCommonEdits() (*cdi.ContainerEdits, error) {
5654
return edits.FromDiscoverer(discover.None{})
5755
}
58-
59-
// GetSpec is unsupported for the gdslib specs.
60-
// gdslib is typically wrapped by a spec that implements GetSpec.
61-
func (l *gdslib) GetSpec(...string) (spec.Interface, error) {
62-
return nil, fmt.Errorf("GetSpec is not supported")
63-
}
64-
65-
// GetGPUDeviceEdits is unsupported for the gdslib specs
66-
func (l *gdslib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
67-
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported")
68-
}
69-
70-
// GetGPUDeviceSpecs is unsupported for the gdslib specs
71-
func (l *gdslib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) {
72-
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported")
73-
}
74-
75-
// GetMIGDeviceEdits is unsupported for the gdslib specs
76-
func (l *gdslib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
77-
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported")
78-
}
79-
80-
// GetMIGDeviceSpecs is unsupported for the gdslib specs
81-
func (l *gdslib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
82-
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported")
83-
}
84-
85-
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
86-
// the provided identifiers, where an identifier is an index or UUID of a valid
87-
// GPU device.
88-
func (l *gdslib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
89-
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported")
90-
}

pkg/nvcdi/lib-csv.go

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,29 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2726
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2827
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
29-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3028
)
3129

3230
type csvlib nvcdilib
3331

34-
var _ Interface = (*csvlib)(nil)
32+
var _ wrapped = (*csvlib)(nil)
3533

36-
// GetSpec should not be called for csvlib
37-
func (l *csvlib) GetSpec(...string) (spec.Interface, error) {
38-
return nil, fmt.Errorf("unexpected call to csvlib.GetSpec()")
39-
}
34+
// GetDeviceSpecsByID returns the device specs for the specified devices.
35+
func (l *csvlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
36+
for _, id := range ids {
37+
switch id {
38+
case "all":
39+
case "0":
40+
default:
41+
return nil, fmt.Errorf("unsupported device id: %v", id)
42+
}
43+
}
4044

41-
// GetAllDeviceSpecs returns the device specs for all available devices.
42-
func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
4345
d, err := tegra.New(
4446
tegra.WithLogger(l.logger),
4547
tegra.WithDriverRoot(l.driverRoot),
@@ -76,33 +78,5 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
7678

7779
// GetCommonEdits generates a CDI specification that can be used for ANY devices
7880
func (l *csvlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
79-
d := discover.None{}
80-
return edits.FromDiscoverer(d)
81-
}
82-
83-
// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices
84-
func (l *csvlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
85-
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported for CSV files")
86-
}
87-
88-
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
89-
func (l *csvlib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
90-
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported for CSV files")
91-
}
92-
93-
// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices
94-
func (l *csvlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
95-
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported for CSV files")
96-
}
97-
98-
// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'.
99-
func (l *csvlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
100-
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported for CSV files")
101-
}
102-
103-
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
104-
// the provided identifiers, where an identifier is an index or UUID of a valid
105-
// GPU device.
106-
func (l *csvlib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
107-
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported for CSV files")
81+
return edits.FromDiscoverer(discover.None{})
10882
}

pkg/nvcdi/lib-imex.go

Lines changed: 40 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -25,64 +25,34 @@ import (
2525
"tags.cncf.io/container-device-interface/pkg/cdi"
2626
"tags.cncf.io/container-device-interface/specs-go"
2727

28-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
29-
3028
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
3129
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
32-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3330
)
3431

3532
type imexlib nvcdilib
3633

37-
var _ Interface = (*imexlib)(nil)
34+
var _ wrapped = (*imexlib)(nil)
3835

3936
const (
4037
classImexChannel = "imex-channel"
4138
)
4239

43-
// GetSpec should not be called for imexlib.
44-
func (l *imexlib) GetSpec(...string) (spec.Interface, error) {
45-
return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()")
46-
}
47-
48-
// GetAllDeviceSpecs returns the device specs for all available devices.
49-
func (l *imexlib) GetAllDeviceSpecs() ([]specs.Device, error) {
50-
channelsDiscoverer := discover.NewCharDeviceDiscoverer(
51-
l.logger,
52-
l.devRoot,
53-
[]string{"/dev/nvidia-caps-imex-channels/channel*"},
54-
)
55-
56-
channels, err := channelsDiscoverer.Devices()
57-
if err != nil {
58-
return nil, err
59-
}
60-
61-
var channelIDs []string
62-
for _, channel := range channels {
63-
channelIDs = append(channelIDs, filepath.Base(channel.Path))
64-
}
65-
66-
return l.GetDeviceSpecsByID(channelIDs...)
67-
}
68-
6940
// GetCommonEdits returns an empty set of edits for IMEX devices.
7041
func (l *imexlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
7142
return edits.FromDiscoverer(discover.None{})
7243
}
7344

7445
// GetDeviceSpecsByID returns the CDI device specs for the IMEX channels specified.
7546
func (l *imexlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
47+
channelsIDs, err := l.getChannelIDs(ids...)
48+
if err != nil {
49+
return nil, err
50+
}
7651
var deviceSpecs []specs.Device
77-
for _, id := range ids {
78-
trimmed := strings.TrimPrefix(id, "channel")
79-
_, err := strconv.ParseUint(trimmed, 10, 64)
80-
if err != nil {
81-
return nil, fmt.Errorf("invalid channel ID %v: %w", id, err)
82-
}
83-
path := "/dev/nvidia-caps-imex-channels/channel" + trimmed
52+
for _, id := range channelsIDs {
53+
path := "/dev/nvidia-caps-imex-channels/channel" + id
8454
deviceSpec := specs.Device{
85-
Name: trimmed,
55+
Name: id,
8656
ContainerEdits: specs.ContainerEdits{
8757
DeviceNodes: []*specs.DeviceNode{
8858
{
@@ -97,22 +67,40 @@ func (l *imexlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
9767
return deviceSpecs, nil
9868
}
9969

100-
// GetGPUDeviceEdits is unsupported for the imexlib specs
101-
func (l *imexlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
102-
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported")
70+
func (l *imexlib) getChannelIDs(ids ...string) ([]string, error) {
71+
var channelIDs []string
72+
for _, id := range ids {
73+
trimmed := strings.TrimPrefix(id, "channel")
74+
if trimmed == "all" {
75+
return l.getAllChannelIDs()
76+
}
77+
_, err := strconv.ParseUint(trimmed, 10, 64)
78+
if err != nil {
79+
return nil, fmt.Errorf("invalid channel ID %v: %w", id, err)
80+
}
81+
channelIDs = append(channelIDs, trimmed)
82+
}
83+
return channelIDs, nil
10384
}
10485

105-
// GetGPUDeviceSpecs is unsupported for the imexlib specs
106-
func (l *imexlib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) {
107-
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported")
108-
}
86+
// getAllChannelIDs returns the device IDs for all available IMEX channels.
87+
func (l *imexlib) getAllChannelIDs() ([]string, error) {
88+
channelsDiscoverer := discover.NewCharDeviceDiscoverer(
89+
l.logger,
90+
l.devRoot,
91+
[]string{"/dev/nvidia-caps-imex-channels/channel*"},
92+
)
10993

110-
// GetMIGDeviceEdits is unsupported for the imexlib specs
111-
func (l *imexlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
112-
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported")
113-
}
94+
channels, err := channelsDiscoverer.Devices()
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
var channelIDs []string
100+
for _, channel := range channels {
101+
channelID := filepath.Base(channel.Path)
102+
channelIDs = append(channelIDs, strings.TrimPrefix(channelID, "channel"))
103+
}
114104

115-
// GetMIGDeviceSpecs is unsupported for the imexlib specs
116-
func (l *imexlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
117-
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported")
105+
return channelIDs, nil
118106
}

pkg/nvcdi/lib-wsl.go

Lines changed: 3 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,18 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
27-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
2826
)
2927

3028
type wsllib nvcdilib
3129

32-
var _ Interface = (*wsllib)(nil)
30+
var _ wrapped = (*wsllib)(nil)
3331

34-
// GetSpec should not be called for wsllib
35-
func (l *wsllib) GetSpec(...string) (spec.Interface, error) {
36-
return nil, fmt.Errorf("unexpected call to wsllib.GetSpec()")
37-
}
38-
39-
// GetAllDeviceSpecs returns the device specs for all available devices.
40-
func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
32+
// GetDeviceSpecsByID returns the device specs for the specified devices.
33+
func (l *wsllib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
4134
device := newDXGDeviceDiscoverer(l.logger, l.devRoot)
4235
deviceEdits, err := edits.FromDiscoverer(device)
4336
if err != nil {
@@ -61,30 +54,3 @@ func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
6154

6255
return edits.FromDiscoverer(driver)
6356
}
64-
65-
// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices
66-
func (l *wsllib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
67-
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported on WSL")
68-
}
69-
70-
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
71-
func (l *wsllib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
72-
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported on WSL")
73-
}
74-
75-
// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices
76-
func (l *wsllib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
77-
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported on WSL")
78-
}
79-
80-
// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'.
81-
func (l *wsllib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
82-
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported on WSL")
83-
}
84-
85-
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
86-
// the provided identifiers, where an identifier is an index or UUID of a valid
87-
// GPU device.
88-
func (l *wsllib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
89-
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported on WSL")
90-
}

pkg/nvcdi/lib.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ func New(opts ...Option) (Interface, error) {
118118
)
119119
}
120120

121-
var lib Interface
121+
var lib wrapped
122122
switch l.resolveMode() {
123123
case ModeCSV:
124124
if len(l.csvFiles) == 0 {
@@ -162,7 +162,7 @@ func New(opts ...Option) (Interface, error) {
162162
)
163163

164164
w := wrapper{
165-
Interface: lib,
165+
wrapped: lib,
166166
vendor: l.vendor,
167167
class: l.class,
168168
mergedDeviceOptions: l.mergedDeviceOptions,

0 commit comments

Comments
 (0)