|
17 | 17 | package nvcdi |
18 | 18 |
|
19 | 19 | import ( |
| 20 | + "errors" |
20 | 21 | "fmt" |
21 | 22 |
|
22 | | - "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" |
23 | 23 | "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" |
24 | 24 | ) |
25 | 25 |
|
| 26 | +// UUIDer is an interface for getting UUIDs. |
| 27 | +type UUIDer interface { |
| 28 | + GetUUID() (string, error) |
| 29 | +} |
| 30 | + |
26 | 31 | // DeviceNamer is an interface for getting device names |
27 | 32 | type DeviceNamer interface { |
28 | | - GetDeviceName(int, device.Device) (string, error) |
29 | | - GetMigDeviceName(int, device.Device, int, device.MigDevice) (string, error) |
| 33 | + GetDeviceName(int, UUIDer) (string, error) |
| 34 | + GetMigDeviceName(int, UUIDer, int, UUIDer) (string, error) |
30 | 35 | } |
31 | 36 |
|
32 | 37 | // Supported device naming strategies |
@@ -61,29 +66,57 @@ func NewDeviceNamer(strategy string) (DeviceNamer, error) { |
61 | 66 | } |
62 | 67 |
|
63 | 68 | // GetDeviceName returns the name for the specified device based on the naming strategy |
64 | | -func (s deviceNameIndex) GetDeviceName(i int, d device.Device) (string, error) { |
| 69 | +func (s deviceNameIndex) GetDeviceName(i int, _ UUIDer) (string, error) { |
65 | 70 | return fmt.Sprintf("%s%d", s.gpuPrefix, i), nil |
66 | 71 | } |
67 | 72 |
|
68 | 73 | // GetMigDeviceName returns the name for the specified device based on the naming strategy |
69 | | -func (s deviceNameIndex) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) { |
| 74 | +func (s deviceNameIndex) GetMigDeviceName(i int, _ UUIDer, j int, _ UUIDer) (string, error) { |
70 | 75 | return fmt.Sprintf("%s%d:%d", s.migPrefix, i, j), nil |
71 | 76 | } |
72 | 77 |
|
73 | 78 | // GetDeviceName returns the name for the specified device based on the naming strategy |
74 | | -func (s deviceNameUUID) GetDeviceName(i int, d device.Device) (string, error) { |
75 | | - uuid, ret := d.GetUUID() |
76 | | - if ret != nvml.SUCCESS { |
77 | | - return "", fmt.Errorf("failed to get device UUID: %v", ret) |
| 79 | +func (s deviceNameUUID) GetDeviceName(i int, d UUIDer) (string, error) { |
| 80 | + uuid, err := d.GetUUID() |
| 81 | + if err != nil { |
| 82 | + return "", fmt.Errorf("failed to get device UUID: %v", err) |
78 | 83 | } |
79 | 84 | return uuid, nil |
80 | 85 | } |
81 | 86 |
|
82 | 87 | // GetMigDeviceName returns the name for the specified device based on the naming strategy |
83 | | -func (s deviceNameUUID) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) { |
84 | | - uuid, ret := mig.GetUUID() |
| 88 | +func (s deviceNameUUID) GetMigDeviceName(i int, _ UUIDer, j int, mig UUIDer) (string, error) { |
| 89 | + uuid, err := mig.GetUUID() |
| 90 | + if err != nil { |
| 91 | + return "", fmt.Errorf("failed to get device UUID: %v", err) |
| 92 | + } |
| 93 | + return uuid, nil |
| 94 | +} |
| 95 | + |
| 96 | +//go:generate moq -stub -out namer_nvml_mock.go . nvmlUUIDer |
| 97 | +type nvmlUUIDer interface { |
| 98 | + GetUUID() (string, nvml.Return) |
| 99 | +} |
| 100 | + |
| 101 | +type convert struct { |
| 102 | + nvmlUUIDer |
| 103 | +} |
| 104 | + |
| 105 | +type uuidUnsupported struct{} |
| 106 | + |
| 107 | +func (m convert) GetUUID() (string, error) { |
| 108 | + if m.nvmlUUIDer == nil { |
| 109 | + return uuidUnsupported{}.GetUUID() |
| 110 | + } |
| 111 | + uuid, ret := m.nvmlUUIDer.GetUUID() |
85 | 112 | if ret != nvml.SUCCESS { |
86 | | - return "", fmt.Errorf("failed to get device UUID: %v", ret) |
| 113 | + return "", ret |
87 | 114 | } |
88 | 115 | return uuid, nil |
89 | 116 | } |
| 117 | + |
| 118 | +var errUUIDUnsupported = errors.New("GetUUID is not supported") |
| 119 | + |
| 120 | +func (m uuidUnsupported) GetUUID() (string, error) { |
| 121 | + return "", errUUIDUnsupported |
| 122 | +} |
0 commit comments