From d394262cf910ea2076f9443153924f325acb29c0 Mon Sep 17 00:00:00 2001 From: zhangxr <944702164@qq.com> Date: Wed, 25 Mar 2026 10:31:30 +0800 Subject: [PATCH] feat: Enhanced features: Added linkage with the hami open source project, supporting dynamic adjustment of GPU devices Signed-off-by: zhangxr <944702164@qq.com> --- cmd/nvidia-device-plugin/main.go | 111 +++++++++++++++++++++++++++++++ internal/plugin/api.go | 1 + internal/plugin/server.go | 19 ++++++ internal/rm/allocate.go | 5 +- internal/rm/nvml_manager.go | 9 +-- internal/rm/rm.go | 42 +++++++++++- internal/rm/rm_mock.go | 41 ++++++++++++ internal/rm/tegra_manager.go | 7 +- 8 files changed, 224 insertions(+), 11 deletions(-) diff --git a/cmd/nvidia-device-plugin/main.go b/cmd/nvidia-device-plugin/main.go index 7aee63fb1..fab51e925 100644 --- a/cmd/nvidia-device-plugin/main.go +++ b/cmd/nvidia-device-plugin/main.go @@ -17,11 +17,14 @@ package main import ( + "context" "encoding/json" "errors" "fmt" "os" "path/filepath" + "regexp" + "strings" "syscall" "time" @@ -30,6 +33,12 @@ import ( "github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/fsnotify/fsnotify" "github.com/urfave/cli/v2" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" "k8s.io/klog/v2" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" @@ -397,9 +406,111 @@ func startPlugins(c *cli.Context, o *options) ([]plugin.Interface, bool, error) klog.Info("No devices found. Waiting indefinitely.") } + go watchHamiNodeAnnotations(c.Context, plugins) + return plugins, false, nil } +func watchHamiNodeAnnotations(ctx context.Context, plugins []plugin.Interface) { + const annotationKey = "hami.io/node-nvidia-register" + + cfg, err := rest.InClusterConfig() + if err != nil { + klog.Warningf("node-watcher: unable to build in-cluster config: %v; skipping annotation watcher", err) + return + } + + cs, err := kubernetes.NewForConfig(cfg) + if err != nil { + klog.Warningf("node-watcher: unable to create clientset: %v; skipping annotation watcher", err) + return + } + + nodeName := os.Getenv("NODE_NAME") + if nodeName == "" { + hn, _ := os.Hostname() + nodeName = hn + klog.Warningf("node-watcher: NODE_NAME not set, falling back to hostname=%s", nodeName) + } + + re := regexp.MustCompile(`(?:hami-core:)?GPU-[0-9a-fA-F-]+`) + lastAnn := "" + + notifyPlugins := func(ann string) { + if ann == lastAnn { + return + } + lastAnn = ann + + matches := re.FindAllString(ann, -1) + seen := map[string]bool{} + var uuids []string + for _, m := range matches { + m = strings.TrimPrefix(m, "hami-core:") + if !seen[m] { + seen[m] = true + uuids = append(uuids, m) + } + } + + klog.Infof("node-watcher: annotation changed, extracted GPUs=%v", uuids) + for _, p := range plugins { + p.HandleAllowedDeviceIDs(uuids) + } + } + + // Initial sync so plugin state is aligned before watch events arrive. + if node, err := cs.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}); err == nil { + ann := "" + if node.Annotations != nil { + ann = node.Annotations[annotationKey] + } + notifyPlugins(ann) + } else { + klog.Warningf("node-watcher: initial get node %s failed: %v", nodeName, err) + } + + for { + if err := ctx.Err(); err != nil { + klog.Infof("node-watcher: context cancelled, exiting: %v", err) + return + } + + timeoutSeconds := int64(300) + w, err := cs.CoreV1().Nodes().Watch(ctx, metav1.ListOptions{ + FieldSelector: fields.OneTermEqualSelector("metadata.name", nodeName).String(), + TimeoutSeconds: &timeoutSeconds, + }) + if err != nil { + klog.Warningf("node-watcher: failed to watch node %s: %v", nodeName, err) + time.Sleep(2 * time.Second) + continue + } + + for event := range w.ResultChan() { + switch event.Type { + case watch.Added, watch.Modified: + node, ok := event.Object.(*corev1.Node) + if !ok || node == nil { + continue + } + ann := "" + if node.Annotations != nil { + ann = node.Annotations[annotationKey] + } + notifyPlugins(ann) + case watch.Deleted: + notifyPlugins("") + case watch.Error: + klog.Warningf("node-watcher: received watch error event for node %s", nodeName) + } + } + + w.Stop() + time.Sleep(1 * time.Second) + } +} + func stopPlugins(plugins []plugin.Interface) error { klog.Info("Stopping plugins.") var errs error diff --git a/internal/plugin/api.go b/internal/plugin/api.go index 92cfa2ecb..3e10bc1fc 100644 --- a/internal/plugin/api.go +++ b/internal/plugin/api.go @@ -23,4 +23,5 @@ type Interface interface { Devices() rm.Devices Start(string) error Stop() error + HandleAllowedDeviceIDs([]string) } diff --git a/internal/plugin/server.go b/internal/plugin/server.go index 8d089fca6..fe0ceedae 100644 --- a/internal/plugin/server.go +++ b/internal/plugin/server.go @@ -63,6 +63,7 @@ type nvidiaDevicePlugin struct { server *grpc.Server health chan *rm.Device stop chan interface{} + update chan struct{} imexChannels imex.Channels @@ -110,13 +111,16 @@ func (plugin *nvidiaDevicePlugin) initialize() { plugin.server = grpc.NewServer([]grpc.ServerOption{}...) plugin.health = make(chan *rm.Device) plugin.stop = make(chan interface{}) + plugin.update = make(chan struct{}, 1) } func (plugin *nvidiaDevicePlugin) cleanup() { close(plugin.stop) + close(plugin.update) plugin.server = nil plugin.health = nil plugin.stop = nil + plugin.update = nil } // Devices returns the full set of devices associated with the plugin. @@ -280,10 +284,25 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil { return nil } + case <-plugin.update: + if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil { + return nil + } } } } +// HandleAllowedDeviceIDs updates the resource manager with a list of UUIDs +// that should be excluded from plugin reporting and triggers an immediate +// ListAndWatch update to kubelet. +func (plugin *nvidiaDevicePlugin) HandleAllowedDeviceIDs(uuids []string) { + plugin.rm.HandleAllowedDeviceIDs(uuids) + select { + case plugin.update <- struct{}{}: + default: + } +} + // GetPreferredAllocation returns the preferred allocation from the set of devices specified in the request func (plugin *nvidiaDevicePlugin) GetPreferredAllocation(ctx context.Context, r *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { response := &pluginapi.PreferredAllocationResponse{} diff --git a/internal/rm/allocate.go b/internal/rm/allocate.go index 166b68e84..1bb2d878c 100644 --- a/internal/rm/allocate.go +++ b/internal/rm/allocate.go @@ -25,8 +25,9 @@ import ( // devices are distributed across all replicated GPUs equally. It takes into // account already allocated replicas to ensure a proper balance across them. func (r *resourceManager) distributedAlloc(available, required []string, size int) ([]string, error) { + devices := r.Devices() // Get the set of candidate devices as the difference between available and required. - candidates := r.devices.Subset(available).Difference(r.devices.Subset(required)).GetIDs() + candidates := devices.Subset(available).Difference(devices.Subset(required)).GetIDs() needed := size - len(required) if len(candidates) < needed { @@ -43,7 +44,7 @@ func (r *resourceManager) distributedAlloc(available, required []string, size in } replicas[id].available++ } - for d := range r.devices { + for d := range devices { id := AnnotatedID(d).GetID() if _, exists := replicas[id]; !exists { continue diff --git a/internal/rm/nvml_manager.go b/internal/rm/nvml_manager.go index fac923429..8528e9510 100644 --- a/internal/rm/nvml_manager.go +++ b/internal/rm/nvml_manager.go @@ -60,9 +60,10 @@ func NewNVMLResourceManagers(infolib info.Interface, nvmllib nvml.Interface, dev } r := &nvmlResourceManager{ resourceManager: resourceManager{ - config: config, - resource: resourceName, - devices: devices, + config: config, + resource: resourceName, + allDevices: devices, + devices: devices, }, nvml: nvmllib, } @@ -92,7 +93,7 @@ func (r *nvmlResourceManager) GetDevicePaths(ids []string) []string { // CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices func (r *nvmlResourceManager) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error { - return r.checkHealth(stop, r.devices, unhealthy) + return r.checkHealth(stop, r.Devices(), unhealthy) } // getPreferredAllocation runs an allocation algorithm over the inputs. diff --git a/internal/rm/rm.go b/internal/rm/rm.go index 33f44b9d8..97be1bb7c 100644 --- a/internal/rm/rm.go +++ b/internal/rm/rm.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "strings" + "sync" "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" @@ -33,7 +34,9 @@ import ( type resourceManager struct { config *spec.Config resource spec.ResourceName - devices Devices + allDevices Devices + devices Devices + mu sync.RWMutex } // ResourceManager provides an interface for listing a set of Devices and checking health on them @@ -42,6 +45,7 @@ type resourceManager struct { type ResourceManager interface { Resource() spec.ResourceName Devices() Devices + HandleAllowedDeviceIDs([]string) GetDevicePaths([]string) []string GetPreferredAllocation(available, required []string, size int) ([]string, error) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error @@ -55,18 +59,52 @@ func (r *resourceManager) Resource() spec.ResourceName { // Devices gets the devices managed by the ResourceManager func (r *resourceManager) Devices() Devices { + r.mu.RLock() + defer r.mu.RUnlock() return r.devices } +// HandleAllowedDeviceIDs updates the exposed device set by excluding the +// supplied GPU UUIDs (from HAMI node annotation). An empty list restores all +// discovered devices. +func (r *resourceManager) HandleAllowedDeviceIDs(uuids []string) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.allDevices == nil { + return + } + + if len(uuids) == 0 { + r.devices = r.allDevices + return + } + + excluded := make(map[string]struct{}, len(uuids)) + for _, id := range uuids { + excluded[id] = struct{}{} + } + + filtered := make(Devices) + for id, d := range r.allDevices { + if _, ok := excluded[d.GetUUID()]; ok { + continue + } + filtered[id] = d + } + r.devices = filtered +} + var errInvalidRequest = errors.New("invalid request") // ValidateRequest checks the requested IDs against the resource manager configuration. // It asserts that all requested IDs are known to the resource manager and that the request is // valid for a specified sharing configuration. func (r *resourceManager) ValidateRequest(ids AnnotatedIDs) error { + devices := r.Devices() // Assert that all requested IDs are known to the resource manager for _, id := range ids { - if !r.devices.Contains(id) { + if !devices.Contains(id) { return fmt.Errorf("%w: unknown device: %s", errInvalidRequest, id) } } diff --git a/internal/rm/rm_mock.go b/internal/rm/rm_mock.go index 4efee5fd9..54e92b51a 100644 --- a/internal/rm/rm_mock.go +++ b/internal/rm/rm_mock.go @@ -50,6 +50,9 @@ type ResourceManagerMock struct { // DevicesFunc mocks the Devices method. DevicesFunc func() Devices + // HandleAllowedDeviceIDsFunc mocks the HandleAllowedDeviceIDs method. + HandleAllowedDeviceIDsFunc func(uuids []string) + // GetDevicePathsFunc mocks the GetDevicePaths method. GetDevicePathsFunc func(strings []string) []string @@ -79,6 +82,11 @@ type ResourceManagerMock struct { // Strings is the strings argument value. Strings []string } + // HandleAllowedDeviceIDs holds details about calls to the HandleAllowedDeviceIDs method. + HandleAllowedDeviceIDs []struct { + // UUIDs is the uuids argument value. + UUIDs []string + } // GetPreferredAllocation holds details about calls to the GetPreferredAllocation method. GetPreferredAllocation []struct { // Available is the available argument value. @@ -100,6 +108,7 @@ type ResourceManagerMock struct { lockCheckHealth sync.RWMutex lockDevices sync.RWMutex lockGetDevicePaths sync.RWMutex + lockHandleAllowedDeviceIDs sync.RWMutex lockGetPreferredAllocation sync.RWMutex lockResource sync.RWMutex lockValidateRequest sync.RWMutex @@ -174,6 +183,38 @@ func (mock *ResourceManagerMock) DevicesCalls() []struct { return calls } +// HandleAllowedDeviceIDs calls HandleAllowedDeviceIDsFunc. +func (mock *ResourceManagerMock) HandleAllowedDeviceIDs(uuids []string) { + callInfo := struct { + UUIDs []string + }{ + UUIDs: uuids, + } + mock.lockHandleAllowedDeviceIDs.Lock() + mock.calls.HandleAllowedDeviceIDs = append(mock.calls.HandleAllowedDeviceIDs, callInfo) + mock.lockHandleAllowedDeviceIDs.Unlock() + if mock.HandleAllowedDeviceIDsFunc == nil { + return + } + mock.HandleAllowedDeviceIDsFunc(uuids) +} + +// HandleAllowedDeviceIDsCalls gets all the calls that were made to HandleAllowedDeviceIDs. +// Check the length with: +// +// len(mockedResourceManager.HandleAllowedDeviceIDsCalls()) +func (mock *ResourceManagerMock) HandleAllowedDeviceIDsCalls() []struct { + UUIDs []string +} { + var calls []struct { + UUIDs []string + } + mock.lockHandleAllowedDeviceIDs.RLock() + calls = mock.calls.HandleAllowedDeviceIDs + mock.lockHandleAllowedDeviceIDs.RUnlock() + return calls +} + // GetDevicePaths calls GetDevicePathsFunc. func (mock *ResourceManagerMock) GetDevicePaths(strings []string) []string { callInfo := struct { diff --git a/internal/rm/tegra_manager.go b/internal/rm/tegra_manager.go index 65ca2022f..1b939926e 100644 --- a/internal/rm/tegra_manager.go +++ b/internal/rm/tegra_manager.go @@ -47,9 +47,10 @@ func NewTegraResourceManagers(config *spec.Config) ([]ResourceManager, error) { } r := &tegraResourceManager{ resourceManager: resourceManager{ - config: config, - resource: resourceName, - devices: devices, + config: config, + resource: resourceName, + allDevices: devices, + devices: devices, }, } if len(devices) != 0 {