Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package nri

import (
"context"
"strings"

"github.com/containerd/nri/pkg/api"
nrilog "github.com/containerd/nri/pkg/log"
"github.com/containerd/nri/pkg/plugin"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

type cdiInjectorPlugin struct {
logger nrilog.Logger
namespace string
}

func NewCDIDeviceInjector(logger logger.Interface, namespace string) interface{} {
return &cdiInjectorPlugin{
logger: toNriLogger{
logger,
},
namespace: namespace,
}
}

// CreateContainer handles container creation requests.
func (c *cdiInjectorPlugin) CreateContainer(ctx context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
adjust := &api.ContainerAdjustment{}

if err := c.injectCDIDevices(ctx, pod, ctr, adjust); err != nil {
return nil, nil, err
}

return adjust, nil, nil
}

func (c *cdiInjectorPlugin) injectCDIDevices(ctx context.Context, pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {

devices := c.parseCDIDevices(ctx, pod, nriCDIAnnotationDomain, ctr.Name)
if len(devices) == 0 {
c.logger.Debugf(ctx, "%s: no CDI devices annotated...", containerName(pod, ctr))
return nil
}

c.logger.Infof(ctx, "%s: injecting CDI devices %v...", containerName(pod, ctr), devices)
for _, name := range devices {
a.AddCDIDevice(
&api.CDIDevice{
Name: name,
},
)
}

return nil
}

// parseCDIDevices processes the podSpec and determines which containers which need CDI devices injected to them
func (c *cdiInjectorPlugin) parseCDIDevices(ctx context.Context, pod *api.PodSandbox, key, container string) []string {
if c.namespace != pod.Namespace {
c.logger.Debugf(ctx, "pod %s/%s is not in the toolkit's namespace %s. Skipping CDI device injection...", pod.Namespace, pod.Name, c.namespace)
return nil
}

cdiDeviceNames, ok := plugin.GetEffectiveAnnotation(pod, key, container)
if !ok {
return nil
}

cdiDevices := strings.Split(cdiDeviceNames, ",")
return cdiDevices
}

// Construct a container name for log messages.
func containerName(pod *api.PodSandbox, container *api.Container) string {
if pod != nil {
return pod.Name + "/" + container.Name
}
return container.Name
}
67 changes: 67 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nri

import (
"fmt"
)

const (
maxPluginIndex = 99
)

// RegistrationConfig holds NRI registration settings shared by all plugins.
type RegistrationConfig struct {
// Index is the plugin index registered with NRI (0-99).
Index uint
// Socket is the path to the NRI socket. When empty, the NRI default is used.
Socket string
}

// ValidateEntries checks that each entry is usable and that plugin indices are unique.
func ValidateEntries(entries []Entry) error {
if len(entries) == 0 {
return nil
}

seenIndex := make(map[uint]string, len(entries))
seenName := make(map[string]struct{}, len(entries))
for i, entry := range entries {
if len(entry.Name) == 0 {
return fmt.Errorf("nri plugin %d: name must be specified", i)
}
if entry.PluginRunner == nil {
return fmt.Errorf("nri plugin %q: implementation must be specified", entry.Name)
}
if entry.Config.Index > maxPluginIndex {
return fmt.Errorf("nri plugin %q: index must be in the range [0,%d]", entry.Name, maxPluginIndex)
}
if other, ok := seenIndex[entry.Config.Index]; ok {
return fmt.Errorf("nri plugin %q: duplicate plugin index %d (already used by %q)", entry.Name, entry.Config.Index, other)
}
seenIndex[entry.Config.Index] = entry.Name
if _, ok := seenName[entry.Name]; ok {
return fmt.Errorf("nri plugin %q: duplicate plugin name", entry.Name)
}
seenName[entry.Name] = struct{}{}
}
return nil
}

func (c RegistrationConfig) pluginIndex() string {
return fmt.Sprintf("%02d", c.Index)
}
112 changes: 112 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nri

import (
"context"
"strings"
"testing"
)

type stubRunner struct{}

func (stubRunner) Start(context.Context, RegistrationConfig) error { return nil }
func (stubRunner) Stop() {}

func TestValidateEntries(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
entries []Entry
wantErr string
}{
{
name: "empty",
entries: nil,
},
{
name: "valid single",
entries: []Entry{
{Name: "management", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 10}},
},
},
{
name: "valid multiple",
entries: []Entry{
{Name: "management", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 10}},
{Name: "other", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 11}},
},
},
{
name: "missing name",
entries: []Entry{
{PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 10}},
},
wantErr: "name must be specified",
},
{
name: "missing runner",
entries: []Entry{
{Name: "management", Config: RegistrationConfig{Index: 10}},
},
wantErr: "implementation must be specified",
},
{
name: "index out of range",
entries: []Entry{
{Name: "management", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 100}},
},
wantErr: "index must be in the range",
},
{
name: "duplicate index",
entries: []Entry{
{Name: "first", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 10}},
{Name: "second", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 10}},
},
wantErr: "duplicate plugin index",
},
{
name: "duplicate name",
entries: []Entry{
{Name: "management", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 10}},
{Name: "management", PluginRunner: stubRunner{}, Config: RegistrationConfig{Index: 11}},
},
wantErr: "duplicate plugin name",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
err := ValidateEntries(tc.entries)
if tc.wantErr == "" {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
return
}
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tc.wantErr) {
t.Fatalf("error %q does not contain %q", err.Error(), tc.wantErr)
}
})
}
}
32 changes: 32 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/entry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nri

import "context"

// Runner is an NRI plugin implementation that can be registered with the container runtime.
type Runner interface {
Start(ctx context.Context, cfg RegistrationConfig) error
Stop()
}

// Entry associates a user-defined name and implementation with NRI registration settings.
type Entry struct {
Name string
Config RegistrationConfig
PluginRunner Runner
}
69 changes: 69 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package nri

import (
"context"
"fmt"

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// Manager owns a set of running NRI plugin instances.
type Manager struct {
logger logger.Interface
plugins []Runner
}

// NewManager creates a manager for starting and stopping NRI plugins.
func NewManager(log logger.Interface) *Manager {
return &Manager{
logger: log,
}
}

// Start initializes each entry. Already-started plugins are stopped if a later
// plugin fails to start.
func (m *Manager) Start(ctx context.Context, entries []Entry) error {
if err := ValidateEntries(entries); err != nil {
return err
}
if len(entries) == 0 {
return nil
}

m.logger.Infof("Starting %d NRI plugin(s)...", len(entries))
for _, entry := range entries {
m.logger.Infof("Starting NRI plugin %q (index=%02d, socket=%s)...",
entry.Name, entry.Config.Index, entry.Config.Socket)

if err := entry.PluginRunner.Start(ctx, entry.Config); err != nil {
m.Stop()
return fmt.Errorf("nri plugin %q (index=%02d): %w", entry.Name, entry.Config.Index, err)
}
m.plugins = append(m.plugins, entry.PluginRunner)
}
return nil
}

// Stop stops all running NRI plugins.
func (m *Manager) Stop() {
for _, plugin := range m.plugins {
plugin.Stop()
}
m.plugins = nil
}
Loading
Loading