Skip to content

Commit 35b23c5

Browse files
cdesiniotiselezar
authored andcommitted
Accept device.Identifiers for requesting CDI specs
This change moves from using strings to useing device.Identifiers as input for requesting CDI specifications for specific devices. Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com> Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent a442a5e commit 35b23c5

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

pkg/nvcdi/lib-nvml.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,17 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
8080
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
8181
// the provided identifiers, where an identifier is an index or UUID of a valid
8282
// GPU device.
83-
func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) {
83+
// Deprecated: Use GetDeviceSpecsBy instead.
84+
func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
85+
var identifiers []device.Identifier
86+
for _, id := range ids {
87+
identifiers = append(identifiers, device.Identifier(id))
88+
}
89+
return l.GetDeviceSpecsBy(identifiers...)
90+
}
91+
92+
// GetDeviceSpecsBy is not supported for the gdslib specs.
93+
func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) {
8494
for _, id := range identifiers {
8595
if id == "all" {
8696
return l.GetAllDeviceSpecs()
@@ -109,7 +119,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
109119
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
110120
}
111121
deviceSpec := specs.Device{
112-
Name: identifiers[i],
122+
Name: string(identifiers[i]),
113123
ContainerEdits: *deviceEdits.ContainerEdits,
114124
}
115125
deviceSpecs = append(deviceSpecs, deviceSpec)
@@ -119,7 +129,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
119129
}
120130

121131
// TODO: move this to go-nvlib?
122-
func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, error) {
132+
func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) {
123133
var devices []nvml.Device
124134
for _, id := range identifiers {
125135
dev, err := l.getNVMLDeviceByID(id)
@@ -131,25 +141,24 @@ func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, erro
131141
return devices, nil
132142
}
133143

134-
func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) {
144+
func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
135145
var err error
136-
devID := device.Identifier(id)
137146

138-
if devID.IsUUID() {
139-
return l.nvmllib.DeviceGetHandleByUUID(id)
147+
if id.IsUUID() {
148+
return l.nvmllib.DeviceGetHandleByUUID(string(id))
140149
}
141150

142-
if devID.IsGpuIndex() {
143-
if idx, err := strconv.Atoi(id); err == nil {
151+
if id.IsGpuIndex() {
152+
if idx, err := strconv.Atoi(string(id)); err == nil {
144153
return l.nvmllib.DeviceGetHandleByIndex(idx)
145154
}
146155
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
147156
}
148157

149-
if devID.IsMigIndex() {
158+
if id.IsMigIndex() {
150159
var gpuIdx, migIdx int
151160
var parent nvml.Device
152-
split := strings.SplitN(id, ":", 2)
161+
split := strings.SplitN(string(id), ":", 2)
153162
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
154163
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
155164
}

0 commit comments

Comments
 (0)