Skip to content

Commit f4d0cfb

Browse files
authored
Merge pull request #318 from cdesiniotis/update-func-signature
Get device specs by Identifier
2 parents 0dc87e5 + 35b23c5 commit f4d0cfb

File tree

1 file changed

+20
-11
lines changed

1 file changed

+20
-11
lines changed

pkg/nvcdi/lib-nvml.go

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,17 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
8080
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
8181
// the provided identifiers, where an identifier is an index or UUID of a valid
8282
// GPU device.
83-
func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) {
83+
// Deprecated: Use GetDeviceSpecsBy instead.
84+
func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
85+
var identifiers []device.Identifier
86+
for _, id := range ids {
87+
identifiers = append(identifiers, device.Identifier(id))
88+
}
89+
return l.GetDeviceSpecsBy(identifiers...)
90+
}
91+
92+
// GetDeviceSpecsBy is not supported for the gdslib specs.
93+
func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) {
8494
for _, id := range identifiers {
8595
if id == "all" {
8696
return l.GetAllDeviceSpecs()
@@ -109,7 +119,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
109119
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
110120
}
111121
deviceSpec := specs.Device{
112-
Name: identifiers[i],
122+
Name: string(identifiers[i]),
113123
ContainerEdits: *deviceEdits.ContainerEdits,
114124
}
115125
deviceSpecs = append(deviceSpecs, deviceSpec)
@@ -119,7 +129,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
119129
}
120130

121131
// TODO: move this to go-nvlib?
122-
func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, error) {
132+
func (l *nvmllib) getNVMLDevicesByID(identifiers ...device.Identifier) ([]nvml.Device, error) {
123133
var devices []nvml.Device
124134
for _, id := range identifiers {
125135
dev, err := l.getNVMLDeviceByID(id)
@@ -131,25 +141,24 @@ func (l *nvmllib) getNVMLDevicesByID(identifiers ...string) ([]nvml.Device, erro
131141
return devices, nil
132142
}
133143

134-
func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) {
144+
func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
135145
var err error
136-
devID := device.Identifier(id)
137146

138-
if devID.IsUUID() {
139-
return l.nvmllib.DeviceGetHandleByUUID(id)
147+
if id.IsUUID() {
148+
return l.nvmllib.DeviceGetHandleByUUID(string(id))
140149
}
141150

142-
if devID.IsGpuIndex() {
143-
if idx, err := strconv.Atoi(id); err == nil {
151+
if id.IsGpuIndex() {
152+
if idx, err := strconv.Atoi(string(id)); err == nil {
144153
return l.nvmllib.DeviceGetHandleByIndex(idx)
145154
}
146155
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
147156
}
148157

149-
if devID.IsMigIndex() {
158+
if id.IsMigIndex() {
150159
var gpuIdx, migIdx int
151160
var parent nvml.Device
152-
split := strings.SplitN(id, ":", 2)
161+
split := strings.SplitN(string(id), ":", 2)
153162
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
154163
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
155164
}

0 commit comments

Comments
 (0)