diff --git a/internal/plugin/server.go b/internal/plugin/server.go index a9d4c8868..48e174c1f 100644 --- a/internal/plugin/server.go +++ b/internal/plugin/server.go @@ -61,8 +61,9 @@ type nvidiaDevicePlugin struct { socket string server *grpc.Server - health chan *rm.Device - stop chan interface{} + + // Health monitoring + healthProvider rm.HealthProvider imexChannels imex.Channels @@ -90,11 +91,11 @@ func (o *options) devicePluginForResource(ctx context.Context, resourceManager r mps: mpsOptions, socket: getPluginSocketPath(resourceManager.Resource()), - // These will be reinitialized every - // time the plugin server is restarted. + + healthProvider: resourceManager.HealthProvider(ctx), + // server will be reinitialized every time the plugin server is + // restarted. server: nil, - health: nil, - stop: nil, } return &plugin, nil } @@ -108,15 +109,10 @@ func getPluginSocketPath(resource spec.ResourceName) string { func (plugin *nvidiaDevicePlugin) initialize() { plugin.server = grpc.NewServer([]grpc.ServerOption{}...) - plugin.health = make(chan *rm.Device) - plugin.stop = make(chan interface{}) } func (plugin *nvidiaDevicePlugin) cleanup() { - close(plugin.stop) plugin.server = nil - plugin.health = nil - plugin.stop = nil } // Devices returns the full set of devices associated with the plugin. @@ -148,13 +144,10 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error { } klog.Infof("Registered device plugin for '%s' with Kubelet", plugin.rm.Resource()) - go func() { - // TODO: add MPS health check - err := plugin.rm.CheckHealth(plugin.stop, plugin.health) - if err != nil { - klog.Errorf("Failed to start health check: %v; continuing with health checks disabled", err) - } - }() + // TODO: add MPS health check + if err := plugin.healthProvider.Start(plugin.ctx); err != nil { + klog.Errorf("Failed to start health provider: %v; continuing with health checks disabled", err) + } return nil } @@ -164,6 +157,10 @@ func (plugin *nvidiaDevicePlugin) Stop() error { if plugin == nil || plugin.server == nil { return nil } + + // Stop health monitoring + plugin.healthProvider.Stop() + klog.Infof("Stopping to serve '%s' on %s", plugin.rm.Resource(), plugin.socket) plugin.server.Stop() if err := os.Remove(plugin.socket); err != nil && !os.IsNotExist(err) { @@ -271,11 +268,11 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D for { select { - case <-plugin.stop: + case <-plugin.ctx.Done(): return nil - case d := <-plugin.health: - // FIXME: there is no way to recover from the Unhealthy state. 
-			d.Health = pluginapi.Unhealthy
+		case d, ok := <-plugin.healthProvider.Health():
+			if !ok {
+				// The health channel was closed by the provider's
+				// Stop(); stop watching.
+				return nil
+			}
+			// Device.Health has already been set to Unhealthy by the
+			// health provider.
 			klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
 			if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
 				return nil
diff --git a/internal/rm/health.go b/internal/rm/health.go
index 1f0fc5c41..5d1f0f40d 100644
--- a/internal/rm/health.go
+++ b/internal/rm/health.go
@@ -17,13 +17,18 @@ package rm

 import (
+	"context"
 	"fmt"
 	"os"
 	"strconv"
 	"strings"
+	"sync"

 	"github.com/NVIDIA/go-nvml/pkg/nvml"
 	"k8s.io/klog/v2"
+	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
+
+	spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1"
 )

 const (
@@ -40,22 +45,73 @@ const (
 	envEnableHealthChecks = "DP_ENABLE_HEALTHCHECKS"
 )

-// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices
-func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devices, unhealthy chan<- *Device) error {
+// HealthProvider manages GPU device health monitoring with lifecycle
+// control.
+type HealthProvider interface {
+	// Start initiates health monitoring. Blocks until initial setup
+	// completes. Returns an error if health monitoring cannot be
+	// started.
+	Start(context.Context) error
+
+	// Stop gracefully shuts down health monitoring and waits for all
+	// goroutines to complete. Safe to call multiple times.
+	Stop()
+
+	// Health returns a read-only channel that receives devices that
+	// have become unhealthy.
+	Health() <-chan *Device
+}
+
+// nvmlHealthProvider implements HealthProvider using NVML event
+// monitoring. This is a refactoring of the existing checkHealth logic
+// with proper lifecycle management.
+type nvmlHealthProvider struct {
+	// Configuration
+	nvml    nvml.Interface
+	config  *spec.Config
+	devices Devices
+
+	// NVML resources
+	eventSet nvml.EventSet
+
+	// Lifecycle management
+	ctx    context.Context
+	cancel context.CancelFunc
+	wg     sync.WaitGroup
+
+	// State guards
+	sync.Mutex
+	started bool
+	stopped bool
+
+	// Communication
+	healthChan chan *Device
+
+	// XID filtering
+	xidsDisabled disabledXIDs
+
+	// Device placement maps (for MIG support)
+	parentToDeviceMap map[string]*Device
+	deviceIDToGiMap   map[string]uint32
+	deviceIDToCiMap   map[string]uint32
+}
+
+// newNVMLHealthProvider creates a new health provider for NVML devices.
+// Does not start monitoring - caller must call Start().
+func newNVMLHealthProvider(ctx context.Context, nvmllib nvml.Interface, config *spec.Config, devices Devices) (HealthProvider, error) {
 	xids := getDisabledHealthCheckXids()
 	if xids.IsAllDisabled() {
-		return nil
+		return &noopHealthProvider{}, nil
 	}

-	ret := r.nvml.Init()
+	ret := nvmllib.Init()
 	if ret != nvml.SUCCESS {
-		if *r.config.Flags.FailOnInitError {
-			return fmt.Errorf("failed to initialize NVML: %v", ret)
+		if *config.Flags.FailOnInitError {
+			return nil, fmt.Errorf("failed to initialize NVML: %v", ret)
 		}
-		return nil
+		klog.Warningf("NVML init failed: %v; health checks disabled", ret)
+		return &noopHealthProvider{}, nil
 	}
 	defer func() {
-		ret := r.nvml.Shutdown()
+		ret := nvmllib.Shutdown()
 		if ret != nvml.SUCCESS {
 			klog.Infof("Error shutting down NVML: %v", ret)
 		}
@@ -63,143 +119,312 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic

 	klog.Infof("Ignoring the following XIDs for health checks: %v", xids)

+	p := &nvmlHealthProvider{
+		ctx:          ctx,
+		nvml:         nvmllib,
+		config:       config,
+		devices:      devices,
+		healthChan:   make(chan *Device, 64),
+		xidsDisabled: xids,
+	}
+	return p, nil
+}
+
+// Start initializes NVML, registers event handlers, and starts the
+// monitoring goroutine. Blocks until initialization completes. The
+// monitoring goroutine runs until Stop is called or the supplied
+// context is cancelled.
+func (r *nvmlHealthProvider) Start(ctx context.Context) (rerr error) {
+	r.Lock()
+	defer r.Unlock()
+	if r.started {
+		// TODO: Is this an error condition? Could we just return?
+		return fmt.Errorf("health provider already started")
+	}
+
+	// Derive a cancellable context for the monitoring goroutine so
+	// that Stop can terminate it. Fall back to the context supplied
+	// at construction time if none is passed.
+	if ctx == nil {
+		ctx = r.ctx
+	}
+	r.ctx, r.cancel = context.WithCancel(ctx)
+	defer func() {
+		if rerr != nil {
+			r.cancel()
+		}
+	}()
+
+	// Initialize NVML
+	ret := r.nvml.Init()
+	if ret != nvml.SUCCESS {
+		return fmt.Errorf("failed to initialize NVML: %v", ret)
+	}
+	defer func() {
+		if rerr != nil {
+			_ = r.nvml.Shutdown()
+		}
+	}()
+
+	// Create event set
 	eventSet, ret := r.nvml.EventSetCreate()
 	if ret != nvml.SUCCESS {
 		return fmt.Errorf("failed to create event set: %v", ret)
 	}
 	defer func() {
-		_ = eventSet.Free()
+		if rerr != nil {
+			_ = eventSet.Free()
+		}
 	}()
+	r.eventSet = eventSet

-	parentToDeviceMap := make(map[string]*Device)
-	deviceIDToGiMap := make(map[string]uint32)
-	deviceIDToCiMap := make(map[string]uint32)
+	// Register devices
+	if err := r.registerDevices(); err != nil {
+		return fmt.Errorf("failed to register devices: %w", err)
+	}

-	eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
-	for _, d := range devices {
-		uuid, gi, ci, err := r.getDevicePlacement(d)
-		if err != nil {
-			klog.Warningf("Could not determine device placement for %v: %v; Marking it unhealthy.", d.ID, err)
-			unhealthy <- d
-			continue
-		}
-		deviceIDToGiMap[d.ID] = gi
-		deviceIDToCiMap[d.ID] = ci
-		parentToDeviceMap[uuid] = d
+	klog.Infof("Health monitoring started for %d devices", len(r.devices))

-		gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
-		if ret != nvml.SUCCESS {
-			klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret)
-			unhealthy <- d
-			continue
-		}
+	// Start monitoring goroutine
+	r.wg.Add(1)
+	go r.runEventMonitor()

-		supportedEvents, ret := gpu.GetSupportedEventTypes()
-		if ret != nvml.SUCCESS {
-			klog.Infof("unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
-			unhealthy <- d
-			continue
-		}
+	r.started = true

-		ret = gpu.RegisterEvents(eventMask&supportedEvents, eventSet)
-		if ret == nvml.ERROR_NOT_SUPPORTED {
-			klog.Warningf("Device %v is too old to support healthchecking.", d.ID)
-		}
+	return nil
+}
+
+// Stop gracefully shuts down health monitoring and waits for the
+// monitoring goroutine to complete.
+func (r *nvmlHealthProvider) Stop() {
+	r.Lock()
+	defer r.Unlock()
+
+	if r.stopped {
+		return
+	}
+
+	if !r.started {
+		r.stopped = true
+		return
+	}
+
+	klog.V(2).Info("Stopping health provider...")
+
+	// Cancel the monitoring context and wait for the event monitor
+	// goroutine to exit (unlock during the wait). The goroutine polls
+	// for cancellation between event-set waits, so this can take up
+	// to the 5-second event timeout.
+	r.cancel()
+	r.Unlock()
+	r.wg.Wait()
+	r.Lock()
+
+	// Cleanup NVML resources
+	r.cleanup()
+
+	// Close channel
+	close(r.healthChan)
+
+	r.stopped = true
+
+	klog.Info("Health provider stopped")
+}
+
+// Health returns a read-only channel that receives devices that have
+// become unhealthy.
+func (r *nvmlHealthProvider) Health() <-chan *Device {
+	return r.healthChan
+}
+
+// cleanup releases NVML resources.
+func (r *nvmlHealthProvider) cleanup() {
+	if r.eventSet != nil {
+		ret := r.eventSet.Free()
 		if ret != nvml.SUCCESS {
-			klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret)
-			unhealthy <- d
+			klog.Warningf("Failed to free event set: %v", ret)
 		}
+		r.eventSet = nil
+	}
+
+	if ret := r.nvml.Shutdown(); ret != nvml.SUCCESS {
+		klog.Warningf("NVML shutdown failed: %v", ret)
 	}
+}
+
+// runEventMonitor monitors NVML events and reports unhealthy devices.
+// This is the existing checkHealth logic refactored into a goroutine.
+func (r *nvmlHealthProvider) runEventMonitor() {
+	defer r.wg.Done()
+
+	klog.V(2).Info("Health check: event monitor started")
+	defer klog.V(2).Info("Health check: event monitor stopped")

 	for {
+		// Check for context cancellation
 		select {
-		case <-stop:
-			return nil
+		case <-r.ctx.Done():
+			return
 		default:
 		}

-		e, ret := eventSet.Wait(5000)
+		// Wait for NVML event (5 second timeout)
+		event, ret := r.eventSet.Wait(5000)
 		if ret == nvml.ERROR_TIMEOUT {
 			continue
 		}
+
 		if ret != nvml.SUCCESS {
-			klog.Infof("Error waiting for event: %v; Marking all devices as unhealthy", ret)
-			for _, d := range devices {
-				unhealthy <- d
+			klog.Infof("Error waiting for event: %v; marking all "+
+				"devices as unhealthy", ret)
+			for _, device := range r.devices {
+				r.sendUnhealthy(device)
 			}
 			continue
 		}

-		if e.EventType != nvml.EventTypeXidCriticalError {
-			klog.Infof("Skipping non-nvmlEventTypeXidCriticalError event: %+v", e)
+		// Only process XID critical errors
+		if event.EventType != nvml.EventTypeXidCriticalError {
+			klog.Infof("Skipping non-nvmlEventTypeXidCriticalError "+
+				"event: %+v", event)
 			continue
 		}

-		if xids.IsDisabled(e.EventData) {
-			klog.Infof("Skipping event %+v", e)
+		// Check if XID is disabled
+		if r.xidsDisabled.IsDisabled(event.EventData) {
+			klog.Infof("Skipping event %+v", event)
 			continue
 		}

-		klog.Infof("Processing event %+v", e)
-		eventUUID, ret := e.Device.GetUUID()
+		klog.Infof("Processing event %+v", event)
+
+		// Find device for event
+		eventUUID, ret := event.Device.GetUUID()
 		if ret != nvml.SUCCESS {
-			// If we cannot reliably determine the device UUID, we mark all devices as unhealthy.
- klog.Infof("Failed to determine uuid for event %v: %v; Marking all devices as unhealthy.", e, ret) - for _, d := range devices { - unhealthy <- d + klog.Infof("Failed to determine uuid for event %v: %v; "+ + "marking all devices as unhealthy.", event, ret) + for _, device := range r.devices { + r.sendUnhealthy(device) } continue } - d, exists := parentToDeviceMap[eventUUID] + device, exists := r.parentToDeviceMap[eventUUID] if !exists { - klog.Infof("Ignoring event for unexpected device: %v", eventUUID) + klog.Infof("Ignoring event for unexpected device: %v", + eventUUID) continue } - if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF { - gi := deviceIDToGiMap[d.ID] - ci := deviceIDToCiMap[d.ID] - if gi != e.GpuInstanceId || ci != e.ComputeInstanceId { + // Handle MIG devices + if device.IsMigDevice() && + event.GpuInstanceId != 0xFFFFFFFF && + event.ComputeInstanceId != 0xFFFFFFFF { + gi := r.deviceIDToGiMap[device.ID] + ci := r.deviceIDToCiMap[device.ID] + + if gi != event.GpuInstanceId || ci != event.ComputeInstanceId { continue } - klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci) + + klog.Infof("Event for mig device %v (gi=%v, ci=%v)", + device.ID, gi, ci) } - klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy.", e.EventData, d.ID) - unhealthy <- d + // Mark device unhealthy + klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking "+ + "device as unhealthy.", event.EventData, device.ID) + + device.Health = pluginapi.Unhealthy + r.sendUnhealthy(device) } } +// sendUnhealthy sends device to unhealthy channel (non-blocking). +func (r *nvmlHealthProvider) sendUnhealthy(device *Device) { + select { + case r.healthChan <- device: + // Sent successfully + default: + // Channel full + klog.Errorf("Health channel full! Device %s update dropped. "+ + "ListAndWatch may be stalled.", device.ID) + // Device.Health already set to Unhealthy + } +} + +// registerDevices registers all devices with the NVML event set. +// This is the existing logic from checkHealth(). 
+func (r *nvmlHealthProvider) registerDevices() error {
+	r.parentToDeviceMap = make(map[string]*Device)
+	r.deviceIDToGiMap = make(map[string]uint32)
+	r.deviceIDToCiMap = make(map[string]uint32)
+
+	eventMask := uint64(nvml.EventTypeXidCriticalError |
+		nvml.EventTypeDoubleBitEccError |
+		nvml.EventTypeSingleBitEccError)
+
+	for _, device := range r.devices {
+		uuid, gi, ci, err := r.getDevicePlacement(device)
+		if err != nil {
+			klog.Warningf("Could not determine device placement for "+
+				"%v: %v; marking it unhealthy.", device.ID, err)
+			r.sendUnhealthy(device)
+			continue
+		}
+
+		r.deviceIDToGiMap[device.ID] = gi
+		r.deviceIDToCiMap[device.ID] = ci
+		r.parentToDeviceMap[uuid] = device
+
+		gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
+		if ret != nvml.SUCCESS {
+			klog.Infof("unable to get device handle from UUID: %v; "+
+				"marking it as unhealthy", ret)
+			r.sendUnhealthy(device)
+			continue
+		}
+
+		supportedEvents, ret := gpu.GetSupportedEventTypes()
+		if ret != nvml.SUCCESS {
+			klog.Infof("unable to determine the supported events for "+
+				"%v: %v; marking it as unhealthy", device.ID, ret)
+			r.sendUnhealthy(device)
+			continue
+		}
+
+		ret = gpu.RegisterEvents(eventMask&supportedEvents, r.eventSet)
+		if ret == nvml.ERROR_NOT_SUPPORTED {
+			klog.Warningf("Device %v is too old to support "+
+				"health checking.", device.ID)
+		}
+		if ret != nvml.SUCCESS {
+			klog.Infof("Marking device %v as unhealthy: %v",
+				device.ID, ret)
+			r.sendUnhealthy(device)
+		}
+	}
+
+	return nil
+}
+
 const allXIDs = 0

 // disabledXIDs stores a map of explicitly disabled XIDs.
-// The special XID `allXIDs` indicates that all XIDs are disabled, but does
-// allow for specific XIDs to be enabled even if this is the case.
+// The special XID `allXIDs` indicates that all XIDs are disabled, but
+// does allow for specific XIDs to be enabled even if this is the case.
 type disabledXIDs map[uint64]bool

-// Disabled returns whether XID-based health checks are disabled.
-// These are considered if all XIDs have been disabled AND no other XIDs have
-// been explcitly enabled.
+// IsAllDisabled returns whether XID-based health checks are disabled.
+// This is the case if all XIDs have been disabled AND no other XIDs
+// have been explicitly enabled.
 func (h disabledXIDs) IsAllDisabled() bool {
 	if allDisabled, ok := h[allXIDs]; ok {
 		return allDisabled
 	}
-	// At this point we wither have explicitly disabled XIDs or explicitly
-	// enabled XIDs. Since ANY XID that's not specified is assumed enabled, we
-	// return here.
+	// At this point we either have explicitly disabled XIDs or
+	// explicitly enabled XIDs. Since ANY XID that's not specified is
+	// assumed enabled, we return here.
 	return false
 }

-// IsDisabled checks whether the specified XID has been explicitly disalbled.
-// An XID is considered disabled if it has been explicitly disabled, or all XIDs
-// have been disabled.
+// IsDisabled checks whether the specified XID has been explicitly
+// disabled. An XID is considered disabled if it has been explicitly
+// disabled, or all XIDs have been disabled.
 func (h disabledXIDs) IsDisabled(xid uint64) bool {
 	// Handle the case where enabled=all.
 	if explicitAll, ok := h[allXIDs]; ok && !explicitAll {
 		return false
 	}
-	// Handle the case where the XID has been specifically enabled (or disabled)
+	// Handle the case where the XID has been specifically enabled (or
+	// disabled)
 	if disabled, ok := h[xid]; ok {
 		return disabled
 	}
@@ -212,17 +437,17 @@ func (h disabledXIDs) IsDisabled(xid uint64) bool {
 // * A list of hardcoded disabled XIDs
 // * A list of explicitly enabled XIDs (including all XIDs)
 //
-// Note that if an XID is explicitly enabled, this takes precedence over it
-// having been disabled either explicitly or implicitly.
+// Note that if an XID is explicitly enabled, this takes precedence over
+// it having been disabled either explicitly or implicitly.
 func getDisabledHealthCheckXids() disabledXIDs {
 	disabled := newHealthCheckXIDs(
-		// TODO: We should not read the envvar here directly, but instead
-		// "upgrade" this to a top-level config option.
+		// TODO: We should not read the envvar here directly, but
+		// instead "upgrade" this to a top-level config option.
 		strings.Split(strings.ToLower(os.Getenv(envDisableHealthChecks)), ",")...,
 	)
 	enabled := newHealthCheckXIDs(
-		// TODO: We should not read the envvar here directly, but instead
-		// "upgrade" this to a top-level config option.
+		// TODO: We should not read the envvar here directly, but
+		// instead "upgrade" this to a top-level config option.
 		strings.Split(strings.ToLower(os.Getenv(envEnableHealthChecks)), ",")...,
 	)

@@ -250,16 +475,16 @@ func getDisabledHealthCheckXids() disabledXIDs {
 }

 // newHealthCheckXIDs converts a list of Xids to a healthCheckXIDs map.
-// Special xid values 'all' and 'xids' return a special map that matches all
-// xids.
-// For other xids, these are converted to a uint64 values with invalid values
-// being ignored.
+// Special xid values 'all' and 'xids' return a special map that matches
+// all xids. For other xids, these are converted to uint64 values with
+// invalid values being ignored.
 func newHealthCheckXIDs(xids ...string) disabledXIDs {
 	output := make(disabledXIDs)
 	for _, xid := range xids {
 		trimmed := strings.TrimSpace(xid)
 		if trimmed == "all" || trimmed == "xids" {
-			// TODO: We should have a different type for "all" and "all-except"
+			// TODO: We should have a different type for "all" and
+			// "all-except"
 			return disabledXIDs{allXIDs: true}
 		}
 		if trimmed == "" {
@@ -267,7 +492,8 @@ func newHealthCheckXIDs(xids ...string) disabledXIDs {
 		}
 		id, err := strconv.ParseUint(trimmed, 10, 64)
 		if err != nil {
-			klog.Infof("Ignoring malformed Xid value %v: %v", trimmed, err)
+			klog.Infof("Ignoring malformed Xid value %v: %v",
+				trimmed, err)
 			continue
 		}

@@ -277,9 +503,10 @@
 }

 // getDevicePlacement returns the placement of the specified device.
-// For a MIG device the placement is defined by the 3-tuple <parent UUID, GI, CI>.
-// For a full device the returned 3-tuple is the device's uuid and 0xFFFFFFFF for the other two elements.
-func (r *nvmlResourceManager) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
+// For a MIG device the placement is defined by the 3-tuple
+// <parent UUID, GI, CI>. For a full device the returned 3-tuple is the
+// device's uuid and 0xFFFFFFFF for the other two elements.
+func (r *nvmlHealthProvider) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
 	if !d.IsMigDevice() {
 		return d.GetUUID(), 0xFFFFFFFF, 0xFFFFFFFF, nil
 	}
@@ -287,40 +514,47 @@
 }

 // getMigDeviceParts returns the parent GI and CI ids of the MIG device.
-func (r *nvmlResourceManager) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
+func (r *nvmlHealthProvider) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
 	if !d.IsMigDevice() {
 		return "", 0, 0, fmt.Errorf("cannot get GI and CI of full device")
 	}

 	uuid := d.GetUUID()
-	// For older driver versions, the call to DeviceGetHandleByUUID will fail for MIG devices.
+	// For older driver versions, the call to DeviceGetHandleByUUID will
+	// fail for MIG devices.
 	mig, ret := r.nvml.DeviceGetHandleByUUID(uuid)
 	if ret == nvml.SUCCESS {
 		parentHandle, ret := mig.GetDeviceHandleFromMigDeviceHandle()
 		if ret != nvml.SUCCESS {
-			return "", 0, 0, fmt.Errorf("failed to get parent device handle: %v", ret)
+			return "", 0, 0, fmt.Errorf("failed to get parent "+
+				"device handle: %v", ret)
 		}

 		parentUUID, ret := parentHandle.GetUUID()
 		if ret != nvml.SUCCESS {
-			return "", 0, 0, fmt.Errorf("failed to get parent uuid: %v", ret)
+			return "", 0, 0, fmt.Errorf("failed to get parent "+
+				"uuid: %v", ret)
 		}
 		gi, ret := mig.GetGpuInstanceId()
 		if ret != nvml.SUCCESS {
-			return "", 0, 0, fmt.Errorf("failed to get GPU Instance ID: %v", ret)
+			return "", 0, 0, fmt.Errorf("failed to get GPU "+
+				"Instance ID: %v", ret)
 		}
 		ci, ret := mig.GetComputeInstanceId()
 		if ret != nvml.SUCCESS {
-			return "", 0, 0, fmt.Errorf("failed to get Compute Instance ID: %v", ret)
+			return "", 0, 0, fmt.Errorf("failed to get Compute "+
+				"Instance ID: %v", ret)
 		}
-		//nolint:gosec // We know that the values returned from Get*InstanceId are within the valid uint32 range.
+		//nolint:gosec // We know that the values returned from Get*InstanceId
+		// are within the valid uint32 range.
 		return parentUUID, uint32(gi), uint32(ci), nil
 	}

 	return parseMigDeviceUUID(uuid)
 }

-// parseMigDeviceUUID splits the MIG device UUID into the parent device UUID and ci and gi
+// parseMigDeviceUUID splits the MIG device UUID into the parent device
+// UUID and its GI and CI.
 func parseMigDeviceUUID(mig string) (string, uint32, uint32, error) {
 	tokens := strings.SplitN(mig, "-", 2)
 	if len(tokens) != 2 || tokens[0] != "MIG" {
@@ -353,3 +587,24 @@ func toUint32(s string) (uint32, error) {
 	//nolint:gosec // Since we parse s with a 32-bit size this will not overflow.
 	return uint32(u), nil
 }
+
+// noopHealthProvider is a no-op implementation for platforms or
+// configurations that don't support health monitoring.
+type noopHealthProvider struct {
+	healthChan chan *Device
+}
+
+func (n *noopHealthProvider) Start(context.Context) error {
+	n.healthChan = make(chan *Device)
+	return nil
+}
+
+// Stop closes the health channel if it was created and resets it to
+// nil so that Stop remains safe to call multiple times.
+func (n *noopHealthProvider) Stop() {
+	if n.healthChan != nil {
+		close(n.healthChan)
+		n.healthChan = nil
+	}
+}
+
+func (n *noopHealthProvider) Health() <-chan *Device {
+	return n.healthChan
+}
diff --git a/internal/rm/nvml_manager.go b/internal/rm/nvml_manager.go
index fac923429..c398edc69 100644
--- a/internal/rm/nvml_manager.go
+++ b/internal/rm/nvml_manager.go
@@ -17,6 +17,7 @@ package rm

 import (
+	"context"
 	"fmt"

 	"github.com/NVIDIA/go-gpuallocator/gpuallocator"
@@ -90,9 +91,20 @@ func (r *nvmlResourceManager) GetDevicePaths(ids []string) []string {
 	return append(paths, r.Devices().Subset(ids).GetPaths()...)
} -// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices -func (r *nvmlResourceManager) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error { - return r.checkHealth(stop, r.devices, unhealthy) +// HealthProvider returns a HealthProvider for NVML device health +// monitoring. Returns a no-op provider if health checks are disabled. +func (r *nvmlResourceManager) HealthProvider(ctx context.Context) HealthProvider { + xids := getDisabledHealthCheckXids() + if xids.IsAllDisabled() { + klog.Info("Health checks disabled via DP_DISABLE_HEALTHCHECKS") + return &noopHealthProvider{} + } + p, err := newNVMLHealthProvider(ctx, r.nvml, r.config, r.devices) + if err != nil { + klog.Errorf("Failed to create NVML health provider: %v", err) + return &noopHealthProvider{} + } + return p } // getPreferredAllocation runs an allocation algorithm over the inputs. diff --git a/internal/rm/rm.go b/internal/rm/rm.go index 33f44b9d8..7ce1b95b4 100644 --- a/internal/rm/rm.go +++ b/internal/rm/rm.go @@ -17,6 +17,7 @@ package rm import ( + "context" "errors" "fmt" "strings" @@ -36,7 +37,8 @@ type resourceManager struct { devices Devices } -// ResourceManager provides an interface for listing a set of Devices and checking health on them +// ResourceManager provides an interface for listing a set of Devices +// and managing their health. // //go:generate moq -rm -fmt=goimports -stub -out rm_mock.go . ResourceManager type ResourceManager interface { @@ -44,8 +46,12 @@ type ResourceManager interface { Devices() Devices GetDevicePaths([]string) []string GetPreferredAllocation(available, required []string, size int) ([]string, error) - CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error ValidateRequest(AnnotatedIDs) error + + // HealthProvider returns a HealthProvider for monitoring device + // health. The context is used for the lifecycle of the health + // monitoring goroutines. + HealthProvider(ctx context.Context) HealthProvider } // Resource gets the resource name associated with the ResourceManager diff --git a/internal/rm/rm_mock.go b/internal/rm/rm_mock.go index 4efee5fd9..9e9d3eed9 100644 --- a/internal/rm/rm_mock.go +++ b/internal/rm/rm_mock.go @@ -4,6 +4,7 @@ package rm import ( + "context" "sync" spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1" @@ -19,9 +20,6 @@ var _ ResourceManager = &ResourceManagerMock{} // // // make and configure a mocked ResourceManager // mockedResourceManager := &ResourceManagerMock{ -// CheckHealthFunc: func(stop <-chan interface{}, unhealthy chan<- *Device) error { -// panic("mock out the CheckHealth method") -// }, // DevicesFunc: func() Devices { // panic("mock out the Devices method") // }, @@ -31,6 +29,9 @@ var _ ResourceManager = &ResourceManagerMock{} // GetPreferredAllocationFunc: func(available []string, required []string, size int) ([]string, error) { // panic("mock out the GetPreferredAllocation method") // }, +// HealthProviderFunc: func(ctx context.Context) HealthProvider { +// panic("mock out the HealthProvider method") +// }, // ResourceFunc: func() spec.ResourceName { // panic("mock out the Resource method") // }, @@ -44,9 +45,6 @@ var _ ResourceManager = &ResourceManagerMock{} // // } type ResourceManagerMock struct { - // CheckHealthFunc mocks the CheckHealth method. - CheckHealthFunc func(stop <-chan interface{}, unhealthy chan<- *Device) error - // DevicesFunc mocks the Devices method. 
DevicesFunc func() Devices @@ -56,6 +54,9 @@ type ResourceManagerMock struct { // GetPreferredAllocationFunc mocks the GetPreferredAllocation method. GetPreferredAllocationFunc func(available []string, required []string, size int) ([]string, error) + // HealthProviderFunc mocks the HealthProvider method. + HealthProviderFunc func(ctx context.Context) HealthProvider + // ResourceFunc mocks the Resource method. ResourceFunc func() spec.ResourceName @@ -64,13 +65,6 @@ type ResourceManagerMock struct { // calls tracks calls to the methods. calls struct { - // CheckHealth holds details about calls to the CheckHealth method. - CheckHealth []struct { - // Stop is the stop argument value. - Stop <-chan interface{} - // Unhealthy is the unhealthy argument value. - Unhealthy chan<- *Device - } // Devices holds details about calls to the Devices method. Devices []struct { } @@ -88,6 +82,11 @@ type ResourceManagerMock struct { // Size is the size argument value. Size int } + // HealthProvider holds details about calls to the HealthProvider method. + HealthProvider []struct { + // Ctx is the ctx argument value. + Ctx context.Context + } // Resource holds details about calls to the Resource method. Resource []struct { } @@ -97,53 +96,14 @@ type ResourceManagerMock struct { AnnotatedIDs AnnotatedIDs } } - lockCheckHealth sync.RWMutex lockDevices sync.RWMutex lockGetDevicePaths sync.RWMutex lockGetPreferredAllocation sync.RWMutex + lockHealthProvider sync.RWMutex lockResource sync.RWMutex lockValidateRequest sync.RWMutex } -// CheckHealth calls CheckHealthFunc. -func (mock *ResourceManagerMock) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error { - callInfo := struct { - Stop <-chan interface{} - Unhealthy chan<- *Device - }{ - Stop: stop, - Unhealthy: unhealthy, - } - mock.lockCheckHealth.Lock() - mock.calls.CheckHealth = append(mock.calls.CheckHealth, callInfo) - mock.lockCheckHealth.Unlock() - if mock.CheckHealthFunc == nil { - var ( - errOut error - ) - return errOut - } - return mock.CheckHealthFunc(stop, unhealthy) -} - -// CheckHealthCalls gets all the calls that were made to CheckHealth. -// Check the length with: -// -// len(mockedResourceManager.CheckHealthCalls()) -func (mock *ResourceManagerMock) CheckHealthCalls() []struct { - Stop <-chan interface{} - Unhealthy chan<- *Device -} { - var calls []struct { - Stop <-chan interface{} - Unhealthy chan<- *Device - } - mock.lockCheckHealth.RLock() - calls = mock.calls.CheckHealth - mock.lockCheckHealth.RUnlock() - return calls -} - // Devices calls DevicesFunc. func (mock *ResourceManagerMock) Devices() Devices { callInfo := struct { @@ -253,6 +213,41 @@ func (mock *ResourceManagerMock) GetPreferredAllocationCalls() []struct { return calls } +// HealthProvider calls HealthProviderFunc. +func (mock *ResourceManagerMock) HealthProvider(ctx context.Context) HealthProvider { + callInfo := struct { + Ctx context.Context + }{ + Ctx: ctx, + } + mock.lockHealthProvider.Lock() + mock.calls.HealthProvider = append(mock.calls.HealthProvider, callInfo) + mock.lockHealthProvider.Unlock() + if mock.HealthProviderFunc == nil { + var ( + healthProviderOut HealthProvider + ) + return healthProviderOut + } + return mock.HealthProviderFunc(ctx) +} + +// HealthProviderCalls gets all the calls that were made to HealthProvider. 
+// Check the length with: +// +// len(mockedResourceManager.HealthProviderCalls()) +func (mock *ResourceManagerMock) HealthProviderCalls() []struct { + Ctx context.Context +} { + var calls []struct { + Ctx context.Context + } + mock.lockHealthProvider.RLock() + calls = mock.calls.HealthProvider + mock.lockHealthProvider.RUnlock() + return calls +} + // Resource calls ResourceFunc. func (mock *ResourceManagerMock) Resource() spec.ResourceName { callInfo := struct { diff --git a/internal/rm/tegra_manager.go b/internal/rm/tegra_manager.go index 65ca2022f..3f1e1635a 100644 --- a/internal/rm/tegra_manager.go +++ b/internal/rm/tegra_manager.go @@ -17,6 +17,7 @@ package rm import ( + "context" "fmt" spec "github.com/NVIDIA/k8s-device-plugin/api/config/v1" @@ -70,7 +71,8 @@ func (r *tegraResourceManager) GetDevicePaths(ids []string) []string { return nil } -// CheckHealth is disabled for the tegraResourceManager -func (r *tegraResourceManager) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error { - return nil +// HealthProvider returns a no-op HealthProvider for Tegra devices. +// Tegra devices do not support health monitoring. +func (r *tegraResourceManager) HealthProvider(ctx context.Context) HealthProvider { + return &noopHealthProvider{} }
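
Reviewer note: the sketch below is a minimal, self-contained illustration of the lifecycle contract the new `HealthProvider` interface implies (Start once, drain `Health()` until the channel is closed, then `Stop()`), mirroring the consumption pattern in the `ListAndWatch` hunk above. The `fakeDevice` and `fakeProvider` types are hypothetical stand-ins invented for this note, not part of the change; only the Start/Stop/Health shape matches the real interface.

```go
package main

import (
	"context"
	"fmt"
)

// fakeDevice stands in for rm.Device; only the ID matters here.
type fakeDevice struct{ ID string }

// fakeProvider follows the same Start/Stop/Health shape as
// rm.HealthProvider, but over fakeDevice for illustration.
type fakeProvider struct {
	ch     chan *fakeDevice
	cancel context.CancelFunc
}

func (p *fakeProvider) Start(ctx context.Context) error {
	ctx, p.cancel = context.WithCancel(ctx)
	p.ch = make(chan *fakeDevice, 64) // buffered, like the real provider
	go func() {
		defer close(p.ch)                // a closed channel signals "stop watching"
		p.ch <- &fakeDevice{ID: "GPU-0"} // simulate one unhealthy-device event
		<-ctx.Done()
	}()
	return nil
}

func (p *fakeProvider) Stop() { p.cancel() }

func (p *fakeProvider) Health() <-chan *fakeDevice { return p.ch }

func main() {
	p := &fakeProvider{}
	if err := p.Start(context.Background()); err != nil {
		fmt.Println("health monitoring unavailable:", err)
		return
	}

	// Consumption mirrors ListAndWatch: each receive is an unhealthy
	// device; a closed channel means the provider has stopped.
	for d := range p.Health() {
		fmt.Printf("device %s marked unhealthy\n", d.ID)
		p.Stop() // in the plugin this happens on server stop/restart
	}
}
```

One design note: because `Stop()` closes the channel, consumers must use a two-value receive (or `range`) rather than assuming every receive yields a device; the `ListAndWatch` hunk was adjusted accordingly.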