From d302c7a97616c977bf20b5edcfe1dc078accb3eb Mon Sep 17 00:00:00 2001 From: dkeven Date: Thu, 22 Jan 2026 16:25:35 +0800 Subject: [PATCH] feat(device-plugin): support dynamic detection of hot-plugged GPUs --- cmd/scheduler/main.go | 1 + .../nvidiadevice/nvinternal/plugin/util.go | 1 + .../nvidiadevice/nvinternal/rm/health.go | 168 ++++++++++-------- .../nvinternal/rm/nvml_manager.go | 124 ++++++++++++- pkg/scheduler/config/config.go | 4 + pkg/scheduler/scheduler.go | 20 ++- 6 files changed, 237 insertions(+), 81 deletions(-) diff --git a/cmd/scheduler/main.go b/cmd/scheduler/main.go index d69d0ddcc..1af9f4c5d 100644 --- a/cmd/scheduler/main.go +++ b/cmd/scheduler/main.go @@ -75,6 +75,7 @@ func init() { rootCmd.Flags().IntVar(&config.Timeout, "kube-timeout", client.DefaultTimeout, "Timeout to use while talking with kube-apiserver.") rootCmd.Flags().BoolVar(&enableProfiling, "profiling", false, "Enable pprof profiling via HTTP server") rootCmd.Flags().DurationVar(&config.NodeLockTimeout, "node-lock-timeout", time.Minute*5, "timeout for node locks") + rootCmd.Flags().DurationVar(&config.CleanupStartupDelay, "cleanup-startup-delay", 90*time.Second, "delay before starting cleanup loops (CleanupGPUBindingsLoop/CleanupPodsWithMissingDevicesLoop)") rootCmd.Flags().BoolVar(&config.ForceOverwriteDefaultScheduler, "force-overwrite-default-scheduler", true, "Overwrite schedulerName in Pod Spec when set to the const DefaultSchedulerName in https://k8s.io/api/core/v1 package") rootCmd.PersistentFlags().AddGoFlagSet(device.GlobalFlagSet()) diff --git a/pkg/device-plugin/nvidiadevice/nvinternal/plugin/util.go b/pkg/device-plugin/nvidiadevice/nvinternal/plugin/util.go index 90df79312..dbc4e9155 100644 --- a/pkg/device-plugin/nvidiadevice/nvinternal/plugin/util.go +++ b/pkg/device-plugin/nvidiadevice/nvinternal/plugin/util.go @@ -186,6 +186,7 @@ func GetMigUUIDFromIndex(uuid string, idx int) string { } func GetMigGpuInstanceIdFromIndex(uuid string, idx int) (int, error) { + 
defer nvml.Shutdown() if nvret := nvml.Init(); nvret != nvml.SUCCESS { klog.Errorln("nvml Init err: ", nvret) return 0, fmt.Errorf("nvml Init err: %s", nvml.ErrorString(nvret)) diff --git a/pkg/device-plugin/nvidiadevice/nvinternal/rm/health.go b/pkg/device-plugin/nvidiadevice/nvinternal/rm/health.go index 7f9c80eef..6cd67c3f2 100644 --- a/pkg/device-plugin/nvidiadevice/nvinternal/rm/health.go +++ b/pkg/device-plugin/nvidiadevice/nvinternal/rm/health.go @@ -63,7 +63,7 @@ const ( ) // CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices -func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhealthy chan<- *Device, disableNVML <-chan bool) error { +func (r *nvmlResourceManager) checkHealth(stop <-chan any, unhealthy chan<- *Device, disableNVML <-chan bool) error { klog.V(4).Info("Check Health start Running") disableHealthChecks := strings.ToLower(os.Getenv(envDisableHealthChecks)) if disableHealthChecks == "all" { @@ -73,20 +73,6 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe return nil } - ret := r.nvml.Init() - if ret != nvml.SUCCESS { - if *r.config.Flags.FailOnInitError { - return fmt.Errorf("failed to initialize NVML: %v", ret) - } - return nil - } - defer func() { - ret := r.nvml.Shutdown() - if ret != nvml.SUCCESS { - klog.Infof("Error shutting down NVML: %v", ret) - } - }() - // FIXME: formalize the full list and document it. 
// http://docs.nvidia.com/deploy/xid-errors/index.html#topic_4 // Application errors: the GPU should still be healthy @@ -107,55 +93,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe skippedXids[additionalXid] = true } - eventSet, ret := r.nvml.EventSetCreate() - if ret != nvml.SUCCESS { - return fmt.Errorf("failed to create event set: %v", ret) - } - defer eventSet.Free() - - parentToDeviceMap := make(map[string]*Device) - deviceIDToGiMap := make(map[string]int) - deviceIDToCiMap := make(map[string]int) - eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError) - for _, d := range devices { - uuid, gi, ci, err := r.getDevicePlacement(d) - if err != nil { - klog.Warningf("Could not determine device placement for %v: %v; Marking it unhealthy.", d.ID, err) - d.Health = kubeletdevicepluginv1beta1.Unhealthy - unhealthy <- d - continue - } - deviceIDToGiMap[d.ID] = gi - deviceIDToCiMap[d.ID] = ci - parentToDeviceMap[uuid] = d - - gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid) - if ret != nvml.SUCCESS { - klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret) - d.Health = kubeletdevicepluginv1beta1.Unhealthy - unhealthy <- d - continue - } - - supportedEvents, ret := gpu.GetSupportedEventTypes() - if ret != nvml.SUCCESS { - klog.Infof("Unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret) - d.Health = kubeletdevicepluginv1beta1.Unhealthy - unhealthy <- d - continue - } - - ret = gpu.RegisterEvents(eventMask&supportedEvents, eventSet) - if ret == nvml.ERROR_NOT_SUPPORTED { - klog.Warningf("Device %v is too old to support healthchecking.", d.ID) - } - if ret != nvml.SUCCESS { - klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret) - d.Health = kubeletdevicepluginv1beta1.Unhealthy - unhealthy <- d - } - } // Track consecutive NVML event errors to avoid flapping successiveEventErrorCount := 0 @@ -167,16 
+105,69 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe // Track consecutive timeouts (no new XID errors) for XID recovery stableTimeoutCount := 0 - for { - select { - case <-stop: + checkLoop := func() error { + ret := r.nvml.Init() + if ret != nvml.SUCCESS { + if *r.config.Flags.FailOnInitError { + return fmt.Errorf("failed to initialize NVML: %v", ret) + } return nil - case signal := <-disableNVML: - if signal { - klog.Info("Check Health has been received close signal") - return fmt.Errorf("close signal received") + } + defer func() { + ret := r.nvml.Shutdown() + if ret != nvml.SUCCESS { + klog.Infof("Error shutting down NVML: %v", ret) + } + }() + + eventSet, ret := r.nvml.EventSetCreate() + if ret != nvml.SUCCESS { + return fmt.Errorf("failed to create event set: %v", ret) + } + defer eventSet.Free() + + parentToDeviceMap := make(map[string]*Device) + deviceIDToGiMap := make(map[string]int) + deviceIDToCiMap := make(map[string]int) + + devices := r.devicesSnapshot() + for _, d := range devices { + uuid, gi, ci, err := r.getDevicePlacement(d) + if err != nil { + klog.Warningf("Could not determine device placement for %v: %v; Marking it unhealthy.", d.ID, err) + d.Health = kubeletdevicepluginv1beta1.Unhealthy + unhealthy <- d + continue + } + deviceIDToGiMap[d.ID] = gi + deviceIDToCiMap[d.ID] = ci + parentToDeviceMap[uuid] = d + + gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid) + if ret != nvml.SUCCESS { + klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret) + d.Health = kubeletdevicepluginv1beta1.Unhealthy + unhealthy <- d + continue + } + + supportedEvents, ret := gpu.GetSupportedEventTypes() + if ret != nvml.SUCCESS { + klog.Infof("Unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret) + d.Health = kubeletdevicepluginv1beta1.Unhealthy + unhealthy <- d + continue + } + + ret = gpu.RegisterEvents(eventMask&supportedEvents, eventSet) + if ret == 
nvml.ERROR_NOT_SUPPORTED { + klog.Warningf("Device %v is too old to support healthchecking.", d.ID) + } + if ret != nvml.SUCCESS { + klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret) + d.Health = kubeletdevicepluginv1beta1.Unhealthy + unhealthy <- d } - default: } e, ret := eventSet.Wait(5000) @@ -210,7 +201,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe stableTimeoutCount = 0 } } - continue + return nil } if ret != nvml.SUCCESS { successiveEventErrorCount++ @@ -224,7 +215,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe } } } - continue + return nil } // Successful event received, reset error counter. // Recovery is handled by the timeout branch once NVML wait stabilizes without errors. @@ -234,12 +225,12 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe if e.EventType != nvml.EventTypeXidCriticalError { klog.Infof("Skipping non-nvmlEventTypeXidCriticalError event: %+v", e) - continue + return nil } if skippedXids[e.EventData] { klog.Infof("Skipping event %+v", e) - continue + return nil } klog.Infof("Processing event %+v", e) @@ -253,13 +244,13 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe xidMarked[d.ID] = true } stableTimeoutCount = 0 - continue + return nil } d, exists := parentToDeviceMap[eventUUID] if !exists { klog.Infof("Ignoring event for unexpected device: %v", eventUUID) - continue + return nil } if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF { @@ -267,11 +258,11 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe ci := deviceIDToCiMap[d.ID] giu32, err := safecast.ToUint32(gi) if err != nil || giu32 != e.GpuInstanceId { - continue + return nil } ciu32, err := safecast.ToUint32(ci) if err != nil || ciu32 != e.ComputeInstanceId { - continue + return nil } klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci) } @@ 
-282,6 +273,25 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan any, devices Devices, unhe // Track device for potential recovery and reset stability counter xidMarked[d.ID] = true stableTimeoutCount = 0 + return nil + } + + for { + select { + case <-stop: + return nil + case signal := <-disableNVML: + if signal { + klog.Info("Check Health has been received close signal") + return fmt.Errorf("close signal received") + } + default: + err := checkLoop() + if err != nil { + return err + } + } + } } diff --git a/pkg/device-plugin/nvidiadevice/nvinternal/rm/nvml_manager.go b/pkg/device-plugin/nvidiadevice/nvinternal/rm/nvml_manager.go index be1f2915b..0008ff163 100644 --- a/pkg/device-plugin/nvidiadevice/nvinternal/rm/nvml_manager.go +++ b/pkg/device-plugin/nvidiadevice/nvinternal/rm/nvml_manager.go @@ -34,6 +34,8 @@ package rm import ( "fmt" + "sync" + "time" "github.com/Project-HAMi/HAMi/pkg/device/nvidia" @@ -44,6 +46,10 @@ import ( type nvmlResourceManager struct { resourceManager nvml nvml.Interface + + mu sync.RWMutex + lastRescan time.Time + rescanInterval time.Duration } var _ ResourceManager = (*nvmlResourceManager)(nil) @@ -86,6 +92,7 @@ func NewNVMLResourceManagers(nvmllib nvml.Interface, config *nvidia.DeviceConfig }, nvml: nvmllib, } + r.rescanInterval = 30 * time.Second rms = append(rms, r) } @@ -114,12 +121,23 @@ func (r *nvmlResourceManager) GetDevicePaths(ids []string) []string { return paths } +// Devices returns a snapshot of devices for this resource. +// +// It also performs a throttled rescan to detect hot-plug/hot-unplug events. 
+// Thread-safety rules: +// - the internal map is always protected by r.mu +// - callers get a shallow copy, so external iteration can't race with internal updates +func (r *nvmlResourceManager) Devices() Devices { + r.maybeRescan() + return r.devicesSnapshot() +} + // CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices func (r *nvmlResourceManager) CheckHealth(stop <-chan any, unhealthy chan<- *Device, disableNVML <-chan bool, ackDisableHealthChecks chan<- bool) error { for { // first check if disableNVML channel signal is pass close into checkHealth function // if signal is pass close, return error "close signal received" - err := r.checkHealth(stop, r.devices, unhealthy, disableNVML) + err := r.checkHealth(stop, unhealthy, disableNVML) if err.Error() == "close signal received" { ackDisableHealthChecks <- true klog.Info("Check Health has been closed") @@ -133,3 +151,107 @@ func (r *nvmlResourceManager) CheckHealth(stop <-chan any, unhealthy chan<- *Dev return err } } + +func (r *nvmlResourceManager) devicesSnapshot() Devices { + r.mu.RLock() + defer r.mu.RUnlock() + return copyDevicesMap(r.resourceManager.devices) +} + +func copyDevicesMap(in Devices) Devices { + out := make(Devices, len(in)) + for id, dev := range in { + out[id] = dev + } + return out +} + +func (r *nvmlResourceManager) maybeRescan() { + // Fast path: check without lock. + if r.rescanInterval > 0 && !r.lastRescan.IsZero() && time.Since(r.lastRescan) < r.rescanInterval { + return + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Re-check under lock (double-checked locking). + if r.rescanInterval > 0 && !r.lastRescan.IsZero() && time.Since(r.lastRescan) < r.rescanInterval { + return + } + + if err := r.rescanLocked(); err != nil { + // Rescan failures should not break the device plugin; log and keep last known-good devices. 
+ klog.ErrorS(err, "Failed to rescan NVML devices; keeping existing device list", "resource", r.resource) + return + } + r.lastRescan = time.Now() +} + +func (r *nvmlResourceManager) rescanLocked() error { + ret := r.nvml.Init() + if ret != nvml.SUCCESS { + if r.config != nil && r.config.Flags.FailOnInitError != nil && *r.config.Flags.FailOnInitError { + return fmt.Errorf("failed to initialize NVML for rescan: %v", ret) + } + return nil + } + defer func() { + ret := r.nvml.Shutdown() + if ret != nvml.SUCCESS { + klog.Infof("Error shutting down NVML after rescan: %v", ret) + } + }() + + newDeviceMap, err := NewDeviceMap(r.nvml, r.config) + if err != nil { + return fmt.Errorf("error building device map during rescan: %v", err) + } + + newDevices, exists := newDeviceMap[r.resource] + if !exists { + newDevices = make(Devices) + } + + for key, value := range newDevices { + if nvidia.FilterDeviceToRegister(value.ID, value.Index) { + klog.V(5).InfoS("Filtering device during rescan", "device", value.ID) + delete(newDevices, key) + } + } + + // Merge: preserve existing *Device pointers to keep Health state. + oldDevices := r.resourceManager.devices + if oldDevices == nil { + oldDevices = make(Devices) + } + + // Add/update. + for id, newDev := range newDevices { + if old, ok := oldDevices[id]; ok && old != nil { + // Preserve health, but refresh metadata that may change across rescans. + old.Paths = newDev.Paths + old.Index = newDev.Index + old.Topology = newDev.Topology + continue + } + oldDevices[id] = newDev + klog.InfoS("Hot-plug: new device detected", "resource", r.resource, "deviceID", id, "index", newDev.Index) + } + + // Remove. 
+ for id, old := range oldDevices { + if _, ok := newDevices[id]; ok { + continue + } + if old != nil { + klog.InfoS("Hot-unplug: device removed", "resource", r.resource, "deviceID", id, "index", old.Index) + } else { + klog.InfoS("Hot-unplug: device removed", "resource", r.resource, "deviceID", id) + } + delete(oldDevices, id) + } + + r.resourceManager.devices = oldDevices + return nil +} diff --git a/pkg/scheduler/config/config.go b/pkg/scheduler/config/config.go index 20545e8f3..5351396fd 100644 --- a/pkg/scheduler/config/config.go +++ b/pkg/scheduler/config/config.go @@ -45,6 +45,10 @@ var ( // NodeLockTimeout is the timeout for node locks. NodeLockTimeout time.Duration + // CleanupStartupDelay is the delay before cleanup loops start running after scheduler startup. + // This prevents aggressive cleanup during cluster/component start up, when the nvidia-driver has not initialized plugged-in devices. + CleanupStartupDelay time.Duration + // If set to false, When Pod.Spec.SchedulerName equals to the const DefaultSchedulerName in k8s.io/api/core/v1 package, webhook will not overwrite it, default value is true. 
ForceOverwriteDefaultScheduler bool ) diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 6553f18cb..0a8f99cc0 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -280,6 +280,15 @@ func (s *Scheduler) RegisterFromNodeAnnotations() { } func (s *Scheduler) CleanupGPUBindingsLoop() { + klog.InfoS("CleanupGPUBindingsLoop: delaying start", "delay", config.CleanupStartupDelay) + timer := time.NewTimer(config.CleanupStartupDelay) + defer timer.Stop() + select { + case <-timer.C: + case <-s.stopCh: + return + } + klog.InfoS("Starting CleanupGPUBindingsLoop") defer klog.InfoS("Exiting CleanupGPUBindingsLoop") ticker := time.NewTicker(15 * time.Second) @@ -400,6 +409,15 @@ func (s *Scheduler) CleanupGPUBindingsLoop() { // CleanupPodsWithMissingDevicesLoop periodically cleans up pods that are assigned // devices which no longer exist in the cluster. func (s *Scheduler) CleanupPodsWithMissingDevicesLoop() { + klog.InfoS("CleanupPodsWithMissingDevicesLoop: delaying start", "delay", config.CleanupStartupDelay) + timer := time.NewTimer(config.CleanupStartupDelay) + defer timer.Stop() + select { + case <-timer.C: + case <-s.stopCh: + return + } + klog.InfoS("Starting CleanupPodsWithMissingDevicesLoop") defer klog.InfoS("Exiting CleanupPodsWithMissingDevicesLoop") ticker := time.NewTicker(30 * time.Second) @@ -436,7 +454,7 @@ func (s *Scheduler) CleanupPodsWithMissingDevicesLoop() { podsToDelete := make([]*podInfo, 0) for _, pod := range scheduledPods { - if pod.Devices == nil || len(pod.Devices) == 0 { + if len(pod.Devices) == 0 { continue }