Skip to content

Commit 91f4a6c

Browse files
[no-relnote] refactor use p as receiver for nvmlHealthProvider
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
1 parent 25a3384 commit 91f4a6c

File tree

1 file changed

+125
-68
lines changed

1 file changed

+125
-68
lines changed

internal/rm/health.go

Lines changed: 125 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -114,69 +114,69 @@ func NewNVMLHealthProvider(
114114

115115
// Start initializes NVML, registers event handlers, and starts the
116116
// monitoring goroutine. Blocks until initialization completes.
117-
func (r *nvmlHealthProvider) Start(ctx context.Context) error {
118-
r.Lock()
119-
if r.started {
120-
r.Unlock()
117+
func (p *nvmlHealthProvider) Start(ctx context.Context) error {
118+
p.Lock()
119+
if p.started {
120+
p.Unlock()
121121
return fmt.Errorf("health provider already started")
122122
}
123-
r.started = true
124-
r.Unlock()
123+
p.started = true
124+
p.Unlock()
125125

126126
// Get XID filter configuration
127-
r.xidsDisabled = getDisabledHealthCheckXids()
127+
p.xidsDisabled = getDisabledHealthCheckXids()
128128

129129
// Initialize NVML
130-
ret := r.nvml.Init()
130+
ret := p.nvml.Init()
131131
if ret != nvml.SUCCESS {
132-
if *r.config.Flags.FailOnInitError {
132+
if *p.config.Flags.FailOnInitError {
133133
return fmt.Errorf("failed to initialize NVML: %v", ret)
134134
}
135135
klog.Warningf("NVML init failed: %v; health checks disabled", ret)
136136
return nil
137137
}
138138

139139
// Create event set
140-
eventSet, ret := r.nvml.EventSetCreate()
140+
eventSet, ret := p.nvml.EventSetCreate()
141141
if ret != nvml.SUCCESS {
142-
if shutdownRet := r.nvml.Shutdown(); shutdownRet != nvml.SUCCESS {
142+
if shutdownRet := p.nvml.Shutdown(); shutdownRet != nvml.SUCCESS {
143143
klog.Warningf("Failed to shutdown NVML: %v", shutdownRet)
144144
}
145145
return fmt.Errorf("failed to create event set: %v", ret)
146146
}
147-
r.eventSet = eventSet
147+
p.eventSet = eventSet
148148

149149
// Register devices
150-
if err := r.registerDevices(); err != nil {
151-
r.cleanup()
150+
if err := p.registerDevices(); err != nil {
151+
p.cleanup()
152152
return fmt.Errorf("failed to register devices: %w", err)
153153
}
154154

155-
klog.Infof("Health monitoring started for %d devices", len(r.devices))
155+
klog.Infof("Health monitoring started for %d devices", len(p.devices))
156156
klog.Infof("Ignoring the following XIDs for health checks: %v",
157-
r.xidsDisabled)
157+
p.xidsDisabled)
158158

159159
// Create child context
160-
r.ctx, r.cancel = context.WithCancel(ctx)
160+
p.ctx, p.cancel = context.WithCancel(ctx)
161161

162162
// Start monitoring goroutine
163-
r.wg.Add(1)
164-
go r.runEventMonitor()
163+
p.wg.Add(1)
164+
go p.runEventMonitor()
165165

166166
return nil
167167
}
168168

169169
// Stop gracefully shuts down health monitoring and waits for the
170170
// monitoring goroutine to complete.
171-
func (r *nvmlHealthProvider) Stop() {
172-
r.Lock()
173-
if r.stopped {
174-
r.Unlock()
171+
func (p *nvmlHealthProvider) Stop() {
172+
p.Lock()
173+
if p.stopped {
174+
p.Unlock()
175175
return
176176
}
177-
r.stopped = true
178-
alreadyStarted := r.started
179-
r.Unlock()
177+
p.stopped = true
178+
alreadyStarted := p.started
179+
p.Unlock()
180180

181181
if !alreadyStarted {
182182
return
@@ -185,61 +185,61 @@ func (r *nvmlHealthProvider) Stop() {
185185
klog.V(2).Info("Stopping health provider...")
186186

187187
// Signal goroutine to stop
188-
if r.cancel != nil {
189-
r.cancel()
188+
if p.cancel != nil {
189+
p.cancel()
190190
}
191191

192192
// Wait for goroutine to finish
193-
r.wg.Wait()
193+
p.wg.Wait()
194194

195195
// Cleanup NVML resources
196-
r.cleanup()
196+
p.cleanup()
197197

198198
// Close channel
199-
close(r.healthChan)
199+
close(p.healthChan)
200200

201201
klog.Info("Health provider stopped")
202202
}
203203

204204
// Health returns a read-only channel that receives devices that have
205205
// become unhealthy.
206-
func (r *nvmlHealthProvider) Health() <-chan *Device {
207-
return r.healthChan
206+
func (p *nvmlHealthProvider) Health() <-chan *Device {
207+
return p.healthChan
208208
}
209209

210210
// cleanup releases NVML resources.
211-
func (r *nvmlHealthProvider) cleanup() {
212-
if r.eventSet != nil {
213-
ret := r.eventSet.Free()
211+
func (p *nvmlHealthProvider) cleanup() {
212+
if p.eventSet != nil {
213+
ret := p.eventSet.Free()
214214
if ret != nvml.SUCCESS {
215215
klog.Warningf("Failed to free event set: %v", ret)
216216
}
217-
r.eventSet = nil
217+
p.eventSet = nil
218218
}
219219

220-
if ret := r.nvml.Shutdown(); ret != nvml.SUCCESS {
220+
if ret := p.nvml.Shutdown(); ret != nvml.SUCCESS {
221221
klog.Warningf("NVML shutdown failed: %v", ret)
222222
}
223223
}
224224

225225
// runEventMonitor monitors NVML events and reports unhealthy devices.
226226
// This is the existing checkHealth logic refactored into a goroutine.
227-
func (r *nvmlHealthProvider) runEventMonitor() {
228-
defer r.wg.Done()
227+
func (p *nvmlHealthProvider) runEventMonitor() {
228+
defer p.wg.Done()
229229

230230
klog.V(2).Info("Health check: event monitor started")
231231
defer klog.V(2).Info("Health check: event monitor stopped")
232232

233233
for {
234234
// Check for context cancellation
235235
select {
236-
case <-r.ctx.Done():
236+
case <-p.ctx.Done():
237237
return
238238
default:
239239
}
240240

241241
// Wait for NVML event (5 second timeout)
242-
event, ret := r.eventSet.Wait(5000)
242+
event, ret := p.eventSet.Wait(5000)
243243

244244
if ret == nvml.ERROR_TIMEOUT {
245245
continue
@@ -248,8 +248,8 @@ func (r *nvmlHealthProvider) runEventMonitor() {
248248
if ret != nvml.SUCCESS {
249249
klog.Infof("Error waiting for event: %v; marking all "+
250250
"devices as unhealthy", ret)
251-
for _, device := range r.devices {
252-
r.sendUnhealthy(device)
251+
for _, device := range p.devices {
252+
p.sendUnhealthy(device)
253253
}
254254
continue
255255
}
@@ -262,7 +262,7 @@ func (r *nvmlHealthProvider) runEventMonitor() {
262262
}
263263

264264
// Check if XID is disabled
265-
if r.xidsDisabled.IsDisabled(event.EventData) {
265+
if p.xidsDisabled.IsDisabled(event.EventData) {
266266
klog.Infof("Skipping event %+v", event)
267267
continue
268268
}
@@ -274,13 +274,13 @@ func (r *nvmlHealthProvider) runEventMonitor() {
274274
if ret != nvml.SUCCESS {
275275
klog.Infof("Failed to determine uuid for event %v: %v; "+
276276
"marking all devices as unhealthy.", event, ret)
277-
for _, device := range r.devices {
278-
r.sendUnhealthy(device)
277+
for _, device := range p.devices {
278+
p.sendUnhealthy(device)
279279
}
280280
continue
281281
}
282282

283-
device, exists := r.parentToDeviceMap[eventUUID]
283+
device, exists := p.parentToDeviceMap[eventUUID]
284284
if !exists {
285285
klog.Infof("Ignoring event for unexpected device: %v",
286286
eventUUID)
@@ -291,8 +291,8 @@ func (r *nvmlHealthProvider) runEventMonitor() {
291291
if device.IsMigDevice() &&
292292
event.GpuInstanceId != 0xFFFFFFFF &&
293293
event.ComputeInstanceId != 0xFFFFFFFF {
294-
gi := r.deviceIDToGiMap[device.ID]
295-
ci := r.deviceIDToCiMap[device.ID]
294+
gi := p.deviceIDToGiMap[device.ID]
295+
ci := p.deviceIDToCiMap[device.ID]
296296

297297
if gi != event.GpuInstanceId || ci != event.ComputeInstanceId {
298298
continue
@@ -307,14 +307,71 @@ func (r *nvmlHealthProvider) runEventMonitor() {
307307
"device as unhealthy.", event.EventData, device.ID)
308308

309309
device.Health = pluginapi.Unhealthy
310-
r.sendUnhealthy(device)
310+
p.sendUnhealthy(device)
311311
}
312312
}
313313

314+
<<<<<<< HEAD
315+
=======
316+
// getDevicePlacement returns the placement of the specified device.
317+
// For a MIG device the placement is defined by the 3-tuple
318+
// <parent UUID, GI, CI>. For a full device the returned 3-tuple is the
319+
// device's uuid and 0xFFFFFFFF for the other two elements.
320+
func (p *nvmlHealthProvider) getDevicePlacement(
321+
d *Device,
322+
) (string, uint32, uint32, error) {
323+
if !d.IsMigDevice() {
324+
return d.GetUUID(), 0xFFFFFFFF, 0xFFFFFFFF, nil
325+
}
326+
return p.getMigDeviceParts(d)
327+
}
328+
329+
// getMigDeviceParts returns the parent UUID, GI, and CI ids of the MIG device.
330+
func (p *nvmlHealthProvider) getMigDeviceParts(
331+
d *Device,
332+
) (string, uint32, uint32, error) {
333+
if !d.IsMigDevice() {
334+
return "", 0, 0, fmt.Errorf("cannot get GI and CI of full device")
335+
}
336+
337+
uuid := d.GetUUID()
338+
// For older driver versions, the call to DeviceGetHandleByUUID will
339+
// fail for MIG devices.
340+
mig, ret := p.nvml.DeviceGetHandleByUUID(uuid)
341+
if ret == nvml.SUCCESS {
342+
parentHandle, ret := mig.GetDeviceHandleFromMigDeviceHandle()
343+
if ret != nvml.SUCCESS {
344+
return "", 0, 0, fmt.Errorf("failed to get parent "+
345+
"device handle: %v", ret)
346+
}
347+
348+
parentUUID, ret := parentHandle.GetUUID()
349+
if ret != nvml.SUCCESS {
350+
return "", 0, 0, fmt.Errorf("failed to get parent "+
351+
"uuid: %v", ret)
352+
}
353+
gi, ret := mig.GetGpuInstanceId()
354+
if ret != nvml.SUCCESS {
355+
return "", 0, 0, fmt.Errorf("failed to get GPU "+
356+
"Instance ID: %v", ret)
357+
}
358+
359+
ci, ret := mig.GetComputeInstanceId()
360+
if ret != nvml.SUCCESS {
361+
return "", 0, 0, fmt.Errorf("failed to get Compute "+
362+
"Instance ID: %v", ret)
363+
}
364+
//nolint:gosec // We know that the values returned from Get*InstanceId are within the valid uint32 range.
365+
return parentUUID, uint32(gi), uint32(ci), nil
366+
}
367+
return parseMigDeviceUUID(uuid)
368+
}
369+
370+
>>>>>>> 3f8110390 ([no-relnote] refactor use p as receiver for nvmlHealthProvider)
314371
// sendUnhealthy sends device to unhealthy channel (non-blocking).
315-
func (r *nvmlHealthProvider) sendUnhealthy(device *Device) {
372+
func (p *nvmlHealthProvider) sendUnhealthy(device *Device) {
316373
select {
317-
case r.healthChan <- device:
374+
case p.healthChan <- device:
318375
// Sent successfully
319376
default:
320377
// Channel full
@@ -326,35 +383,35 @@ func (r *nvmlHealthProvider) sendUnhealthy(device *Device) {
326383

327384
// registerDevices registers all devices with the NVML event set.
328385
// This is the existing logic from checkHealth().
329-
func (r *nvmlHealthProvider) registerDevices() error {
330-
r.parentToDeviceMap = make(map[string]*Device)
331-
r.deviceIDToGiMap = make(map[string]uint32)
332-
r.deviceIDToCiMap = make(map[string]uint32)
386+
func (p *nvmlHealthProvider) registerDevices() error {
387+
p.parentToDeviceMap = make(map[string]*Device)
388+
p.deviceIDToGiMap = make(map[string]uint32)
389+
p.deviceIDToCiMap = make(map[string]uint32)
333390

334391
eventMask := uint64(nvml.EventTypeXidCriticalError |
335392
nvml.EventTypeDoubleBitEccError |
336393
nvml.EventTypeSingleBitEccError)
337394

338-
for _, device := range r.devices {
339-
uuid, gi, ci, err := r.getDevicePlacement(device)
395+
for _, device := range p.devices {
396+
uuid, gi, ci, err := p.getDevicePlacement(device)
340397
if err != nil {
341398
klog.Warningf("Could not determine device placement for "+
342399
"%v: %v; marking it unhealthy.", device.ID, err)
343400
device.Health = pluginapi.Unhealthy
344-
r.sendUnhealthy(device)
401+
p.sendUnhealthy(device)
345402
continue
346403
}
347404

348-
r.deviceIDToGiMap[device.ID] = gi
349-
r.deviceIDToCiMap[device.ID] = ci
350-
r.parentToDeviceMap[uuid] = device
405+
p.deviceIDToGiMap[device.ID] = gi
406+
p.deviceIDToCiMap[device.ID] = ci
407+
p.parentToDeviceMap[uuid] = device
351408

352-
gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
409+
gpu, ret := p.nvml.DeviceGetHandleByUUID(uuid)
353410
if ret != nvml.SUCCESS {
354411
klog.Infof("unable to get device handle from UUID: %v; "+
355412
"marking it as unhealthy", ret)
356413
device.Health = pluginapi.Unhealthy
357-
r.sendUnhealthy(device)
414+
p.sendUnhealthy(device)
358415
continue
359416
}
360417

@@ -363,11 +420,11 @@ func (r *nvmlHealthProvider) registerDevices() error {
363420
klog.Infof("unable to determine the supported events for "+
364421
"%v: %v; marking it as unhealthy", device.ID, ret)
365422
device.Health = pluginapi.Unhealthy
366-
r.sendUnhealthy(device)
423+
p.sendUnhealthy(device)
367424
continue
368425
}
369426

370-
ret = gpu.RegisterEvents(eventMask&supportedEvents, r.eventSet)
427+
ret = gpu.RegisterEvents(eventMask&supportedEvents, p.eventSet)
371428
if ret == nvml.ERROR_NOT_SUPPORTED {
372429
klog.Warningf("Device %v is too old to support "+
373430
"healthchecking.", device.ID)
@@ -376,7 +433,7 @@ func (r *nvmlHealthProvider) registerDevices() error {
376433
klog.Infof("Marking device %v as unhealthy: %v",
377434
device.ID, ret)
378435
device.Health = pluginapi.Unhealthy
379-
r.sendUnhealthy(device)
436+
p.sendUnhealthy(device)
380437
}
381438
}
382439

0 commit comments

Comments
 (0)