Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions cmd/nvidia-device-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
package main

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"
"time"

Expand All @@ -30,6 +33,12 @@ import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/fsnotify/fsnotify"
"github.com/urfave/cli/v2"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/klog/v2"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"

Expand Down Expand Up @@ -397,9 +406,111 @@ func startPlugins(c *cli.Context, o *options) ([]plugin.Interface, bool, error)
klog.Info("No devices found. Waiting indefinitely.")
}

go watchHamiNodeAnnotations(c.Context, plugins)

return plugins, false, nil
}

// watchHamiNodeAnnotations follows the "hami.io/node-nvidia-register"
// annotation on this node and forwards the GPU UUIDs found in it to every
// plugin via HandleAllowedDeviceIDs.
//
// The node is identified by the NODE_NAME environment variable, falling back
// to the OS hostname. The function blocks until ctx is cancelled, so it is
// meant to be run in its own goroutine. It returns early (after logging a
// warning) when no in-cluster configuration, clientset, or node name can be
// obtained.
func watchHamiNodeAnnotations(ctx context.Context, plugins []plugin.Interface) {
	const (
		annotationKey     = "hami.io/node-nvidia-register"
		watchErrorBackoff = 2 * time.Second
		rewatchDelay      = 1 * time.Second
	)

	cfg, err := rest.InClusterConfig()
	if err != nil {
		klog.Warningf("node-watcher: unable to build in-cluster config: %v; skipping annotation watcher", err)
		return
	}

	cs, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		klog.Warningf("node-watcher: unable to create clientset: %v; skipping annotation watcher", err)
		return
	}

	nodeName := os.Getenv("NODE_NAME")
	if nodeName == "" {
		hn, err := os.Hostname()
		if err != nil || hn == "" {
			// Without a name the field selector below would match nothing and
			// the watcher would spin forever without ever notifying a plugin.
			klog.Warningf("node-watcher: NODE_NAME not set and hostname lookup failed: %v; skipping annotation watcher", err)
			return
		}
		nodeName = hn
		klog.Warningf("node-watcher: NODE_NAME not set, falling back to hostname=%s", nodeName)
	}

	re := regexp.MustCompile(`(?:hami-core:)?GPU-[0-9a-fA-F-]+`)
	lastAnn := ""

	// notifyPlugins pushes the UUIDs found in ann to all plugins, but only
	// when the raw annotation value differs from the last one processed.
	notifyPlugins := func(ann string) {
		if ann == lastAnn {
			return
		}
		lastAnn = ann

		matches := re.FindAllString(ann, -1)
		seen := make(map[string]bool, len(matches))
		var uuids []string
		for _, m := range matches {
			m = strings.TrimPrefix(m, "hami-core:")
			if !seen[m] {
				seen[m] = true
				uuids = append(uuids, m)
			}
		}

		klog.Infof("node-watcher: annotation changed, extracted GPUs=%v", uuids)
		for _, p := range plugins {
			p.HandleAllowedDeviceIDs(uuids)
		}
	}

	// sleep waits for d, returning early if ctx is cancelled so that shutdown
	// is not delayed by a pending back-off.
	sleep := func(d time.Duration) {
		t := time.NewTimer(d)
		defer t.Stop()
		select {
		case <-ctx.Done():
		case <-t.C:
		}
	}

	// Initial sync so plugin state is aligned before watch events arrive.
	if node, err := cs.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}); err == nil {
		// Indexing a nil annotation map yields "", so no nil check is needed.
		notifyPlugins(node.Annotations[annotationKey])
	} else {
		klog.Warningf("node-watcher: initial get node %s failed: %v", nodeName, err)
	}

	for {
		if err := ctx.Err(); err != nil {
			klog.Infof("node-watcher: context cancelled, exiting: %v", err)
			return
		}

		timeoutSeconds := int64(300)
		w, err := cs.CoreV1().Nodes().Watch(ctx, metav1.ListOptions{
			FieldSelector:  fields.OneTermEqualSelector("metadata.name", nodeName).String(),
			TimeoutSeconds: &timeoutSeconds,
		})
		if err != nil {
			klog.Warningf("node-watcher: failed to watch node %s: %v", nodeName, err)
			sleep(watchErrorBackoff)
			continue
		}

	events:
		for event := range w.ResultChan() {
			switch event.Type {
			case watch.Added, watch.Modified:
				node, ok := event.Object.(*corev1.Node)
				if !ok || node == nil {
					continue
				}
				notifyPlugins(node.Annotations[annotationKey])
			case watch.Deleted:
				notifyPlugins("")
			case watch.Error:
				if st, ok := event.Object.(*metav1.Status); ok && st != nil {
					klog.Warningf("node-watcher: received watch error event for node %s: %s (reason=%s, code=%d)",
						nodeName, st.Message, st.Reason, st.Code)
				} else {
					klog.Warningf("node-watcher: received watch error event for node %s", nodeName)
				}
				// Treat an error event as terminal for this watch and start
				// a fresh one on the next loop iteration.
				break events
			}
		}

		w.Stop()
		sleep(rewatchDelay)
	}
}

func stopPlugins(plugins []plugin.Interface) error {
klog.Info("Stopping plugins.")
var errs error
Expand Down
1 change: 1 addition & 0 deletions internal/plugin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ type Interface interface {
Devices() rm.Devices
Start(string) error
Stop() error
HandleAllowedDeviceIDs([]string)
}
19 changes: 19 additions & 0 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type nvidiaDevicePlugin struct {
server *grpc.Server
health chan *rm.Device
stop chan interface{}
update chan struct{}

imexChannels imex.Channels

Expand Down Expand Up @@ -110,13 +111,16 @@ func (plugin *nvidiaDevicePlugin) initialize() {
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
plugin.health = make(chan *rm.Device)
plugin.stop = make(chan interface{})
plugin.update = make(chan struct{}, 1)
}

// cleanup signals shutdown by closing the stop channel and then drops the
// plugin's references to its server and channels.
func (plugin *nvidiaDevicePlugin) cleanup() {
	close(plugin.stop)
	// plugin.update is deliberately NOT closed here. HandleAllowedDeviceIDs
	// sends on it from another goroutine, and a send on a closed channel
	// panics; a closed channel would also make the `<-plugin.update` case in
	// ListAndWatch fire continuously. Dropping the reference is enough: a
	// send on the resulting nil channel falls through to the sender's
	// `default` branch.
	plugin.server = nil
	plugin.health = nil
	plugin.stop = nil
	plugin.update = nil
}

// Devices returns the full set of devices associated with the plugin.
Expand Down Expand Up @@ -280,10 +284,25 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
case <-plugin.update:
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
}
}
}

// HandleAllowedDeviceIDs forwards uuids to the plugin's resource manager and
// then requests a fresh ListAndWatch response so kubelet sees the new device
// list.
//
// The request is a non-blocking send on plugin.update: if the send cannot
// proceed immediately it is dropped rather than blocking the caller.
func (plugin *nvidiaDevicePlugin) HandleAllowedDeviceIDs(uuids []string) {
	plugin.rm.HandleAllowedDeviceIDs(uuids)

	refresh := struct{}{}
	select {
	case plugin.update <- refresh:
		// Refresh request queued.
	default:
		// Send would block; skip it.
	}
}

// GetPreferredAllocation returns the preferred allocation from the set of devices specified in the request
func (plugin *nvidiaDevicePlugin) GetPreferredAllocation(ctx context.Context, r *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) {
response := &pluginapi.PreferredAllocationResponse{}
Expand Down
5 changes: 3 additions & 2 deletions internal/rm/allocate.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ import (
// devices are distributed across all replicated GPUs equally. It takes into
// account already allocated replicas to ensure a proper balance across them.
func (r *resourceManager) distributedAlloc(available, required []string, size int) ([]string, error) {
devices := r.Devices()
// Get the set of candidate devices as the difference between available and required.
candidates := r.devices.Subset(available).Difference(r.devices.Subset(required)).GetIDs()
candidates := devices.Subset(available).Difference(devices.Subset(required)).GetIDs()
needed := size - len(required)

if len(candidates) < needed {
Expand All @@ -43,7 +44,7 @@ func (r *resourceManager) distributedAlloc(available, required []string, size in
}
replicas[id].available++
}
for d := range r.devices {
for d := range devices {
id := AnnotatedID(d).GetID()
if _, exists := replicas[id]; !exists {
continue
Expand Down
9 changes: 5 additions & 4 deletions internal/rm/nvml_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ func NewNVMLResourceManagers(infolib info.Interface, nvmllib nvml.Interface, dev
}
r := &nvmlResourceManager{
resourceManager: resourceManager{
config: config,
resource: resourceName,
devices: devices,
config: config,
resource: resourceName,
allDevices: devices,
devices: devices,
},
nvml: nvmllib,
}
Expand Down Expand Up @@ -92,7 +93,7 @@ func (r *nvmlResourceManager) GetDevicePaths(ids []string) []string {

// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices
func (r *nvmlResourceManager) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error {
return r.checkHealth(stop, r.devices, unhealthy)
return r.checkHealth(stop, r.Devices(), unhealthy)
}

// getPreferredAllocation runs an allocation algorithm over the inputs.
Expand Down
42 changes: 40 additions & 2 deletions internal/rm/rm.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"errors"
"fmt"
"strings"
"sync"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
Expand All @@ -33,7 +34,9 @@ import (
type resourceManager struct {
config *spec.Config
resource spec.ResourceName
devices Devices
allDevices Devices
devices Devices
mu sync.RWMutex
}

// ResourceManager provides an interface for listing a set of Devices and checking health on them
Expand All @@ -42,6 +45,7 @@ type resourceManager struct {
type ResourceManager interface {
Resource() spec.ResourceName
Devices() Devices
HandleAllowedDeviceIDs([]string)
GetDevicePaths([]string) []string
GetPreferredAllocation(available, required []string, size int) ([]string, error)
CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error
Expand All @@ -55,18 +59,52 @@ func (r *resourceManager) Resource() spec.ResourceName {

// Devices returns the device set currently exposed by the ResourceManager.
// The field is read under the read lock so that it does not race with
// HandleAllowedDeviceIDs replacing it.
func (r *resourceManager) Devices() Devices {
	r.mu.RLock()
	current := r.devices
	r.mu.RUnlock()
	return current
}

// HandleAllowedDeviceIDs recomputes the exposed device set from allDevices,
// leaving out every device whose UUID appears in uuids. Note that, despite
// the method name, the supplied UUIDs are the ones being hidden.
//
// An empty uuids list re-exposes every discovered device. When allDevices
// was never populated the call is a no-op. A new map is built on each call
// rather than mutating the one already handed out by Devices.
func (r *resourceManager) HandleAllowedDeviceIDs(uuids []string) {
	r.mu.Lock()
	defer r.mu.Unlock()

	switch {
	case r.allDevices == nil:
		// Nothing to filter against.
		return
	case len(uuids) == 0:
		// No exclusions: go back to the full discovered set.
		r.devices = r.allDevices
		return
	}

	hidden := make(map[string]struct{}, len(uuids))
	for _, uuid := range uuids {
		hidden[uuid] = struct{}{}
	}

	visible := make(Devices)
	for key, dev := range r.allDevices {
		if _, skip := hidden[dev.GetUUID()]; !skip {
			visible[key] = dev
		}
	}
	r.devices = visible
}

var errInvalidRequest = errors.New("invalid request")

// ValidateRequest checks the requested IDs against the resource manager configuration.
// It asserts that all requested IDs are known to the resource manager and that the request is
// valid for a specified sharing configuration.
func (r *resourceManager) ValidateRequest(ids AnnotatedIDs) error {
devices := r.Devices()
// Assert that all requested IDs are known to the resource manager
for _, id := range ids {
if !r.devices.Contains(id) {
if !devices.Contains(id) {
return fmt.Errorf("%w: unknown device: %s", errInvalidRequest, id)
}
}
Expand Down
41 changes: 41 additions & 0 deletions internal/rm/rm_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading