From 68fa4bbc2308b03c54dee15cc1cd92ed541bdd56 Mon Sep 17 00:00:00 2001 From: Arjun Date: Wed, 29 Apr 2026 15:48:28 +0000 Subject: [PATCH 1/2] Add support for requirements checks to CDI Signed-off-by: Arjun --- internal/modifier/cdi.go | 8 + internal/modifier/csv.go | 37 +---- internal/modifier/image_requirements.go | 200 ++++++++++++++++++++++++ 3 files changed, 209 insertions(+), 36 deletions(-) create mode 100644 internal/modifier/image_requirements.go diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 005c42f48..260005925 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -51,6 +51,14 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) { defaultKind, ) devices := deviceRequestor.DeviceRequests() + + // Run before the empty-device return so NVIDIA_REQUIRE_* is still enforced when + // len(devices)==0 (e.g. CRI CDI injection without matching spec signals). When + // there are no requirements, checkRequirements returns immediately. + if err := checkRequirements(f.logger, f.image, f.driver); err != nil { + return nil, fmt.Errorf("requirements not met: %w", err) + } + if len(devices) == 0 { f.logger.Debugf("No devices requested; no modification required.") return nil, nil diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index b20fdb134..2b93d6041 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -20,10 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" - "github.com/NVIDIA/nvidia-container-toolkit/internal/cuda" - "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" - "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" ) // newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. 
@@ -36,45 +33,13 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) { } f.logger.Infof("Constructing modifier from config: %+v", *f.cfg) - if err := checkRequirements(f.logger, f.image); err != nil { + if err := checkRequirements(f.logger, f.image, f.driver); err != nil { return nil, fmt.Errorf("requirements not met: %v", err) } return f.newAutomaticCDISpecModifier(devices) } -func checkRequirements(logger logger.Interface, image *image.CUDA) error { - if image == nil || image.HasDisableRequire() { - // TODO: We could print the real value here instead - logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true) - return nil - } - - imageRequirements, err := image.GetRequirements() - if err != nil { - // TODO: Should we treat this as a failure, or just issue a warning? - return fmt.Errorf("failed to get image requirements: %v", err) - } - - r := requirements.New(logger, imageRequirements) - - cudaVersion, err := cuda.Version() - if err != nil { - logger.Warningf("Failed to get CUDA version: %v", err) - } else { - r.AddVersionProperty(requirements.CUDA, cudaVersion) - } - - compteCapability, err := cuda.ComputeCapability(0) - if err != nil { - logger.Warningf("Failed to get CUDA Compute Capability: %v", err) - } else { - r.AddVersionProperty(requirements.ARCH, compteCapability) - } - - return r.Assert() -} - type csvDevices image.CUDA func (d csvDevices) DeviceRequests() []string { diff --git a/internal/modifier/image_requirements.go b/internal/modifier/image_requirements.go new file mode 100644 index 000000000..b36ff4480 --- /dev/null +++ b/internal/modifier/image_requirements.go @@ -0,0 +1,200 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package modifier
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/NVIDIA/go-nvml/pkg/nvml"
+	"golang.org/x/mod/semver"
+
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
+)
+
+// checkRequirements evaluates NVIDIA_REQUIRE_* constraints using the host
+// CUDA driver API version from libcuda, the NVIDIA display driver version from
+// the driver root (libcuda / libnvidia-ml soname), the compute capability of
+// CUDA device 0, and the brand of the GPU at NVML index 0. Properties that
+// cannot be determined are logged and omitted. Used for CSV and CDI / JIT-CDI.
+func checkRequirements(logger logger.Interface, image *image.CUDA, driver *root.Driver) error {
+	if image == nil || image.HasDisableRequire() {
+		logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
+		return nil
+	}
+
+	imageRequirements, err := image.GetRequirements()
+	if err != nil {
+		return fmt.Errorf("failed to get image requirements: %w", err)
+	}
+	if len(imageRequirements) == 0 {
+		return nil
+	}
+
+	r := requirements.New(logger, imageRequirements)
+
+	cudaVersion, err := cuda.Version()
+	if err != nil {
+		logger.Warningf("Failed to get CUDA version: %v", err)
+	} else {
+		r.AddVersionProperty(requirements.CUDA, cudaVersion)
+	}
+
+	computeCapability, err := cuda.ComputeCapability(0)
+	if err != nil {
+		logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
+	} else {
+		r.AddVersionProperty(requirements.ARCH, computeCapability)
+	}
+
+	driverVersion, err := driver.Version()
+	if err != nil {
+		logger.Warningf("Failed to get NVIDIA driver version: %v", err)
+	} else {
+		normalized, normErr := normalizeDriverVersionForSemver(driverVersion)
+		if normErr != nil {
+			logger.Warningf("NVIDIA driver version %q is not semver-normalizable: %v", driverVersion, normErr)
+		} else {
+			r.AddVersionProperty(requirements.DRIVER, normalized)
+		}
+	}
+
+	brand, err := getBrandFromNVML(driver)
+	if err != nil {
+		logger.Warningf("Failed to get GPU brand from NVML: %v", err)
+	} else {
+		r.AddStringProperty(requirements.BRAND, brand)
+	}
+
+	return r.Assert()
+}
+
+// normalizeDriverVersionForSemver converts a driver version taken from a
+// libcuda / libnvidia-ml soname suffix into a form accepted by
+// golang.org/x/mod/semver (no leading zeros in numeric segments).
+func normalizeDriverVersionForSemver(raw string) (string, error) {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return "", fmt.Errorf("empty driver version")
+	}
+	parts := strings.Split(raw, ".")
+	out := make([]string, 0, len(parts))
+	for _, p := range parts {
+		if p == "" {
+			return "", fmt.Errorf("empty version segment in %q", raw)
+		}
+		if strings.TrimLeft(p, "0123456789") != "" {
+			return "", fmt.Errorf("non-numeric version segment %q in %q", p, raw)
+		}
+		n, err := strconv.ParseUint(p, 10, 64)
+		if err != nil {
+			return "", fmt.Errorf("invalid version segment %q in %q: %w", p, raw, err)
+		}
+		out = append(out, strconv.FormatUint(n, 10))
+	}
+	normalized := strings.Join(out, ".")
+	if !semver.IsValid("v" + normalized) {
+		return "", fmt.Errorf("normalized driver version %q is not valid semver", normalized)
+	}
+	return normalized, nil
+}
+
+// getBrandFromNVML returns a lowercase brand token for the GPU at NVML index
+// 0. driver is dereferenced unconditionally and is expected to be non-nil; its
+// version is used to locate a versioned libnvidia-ml under the driver root.
+func getBrandFromNVML(driver *root.Driver) (string, error) {
+	// opts stays empty unless the versioned library is located under the driver root.
+	var opts []nvml.LibraryOption
+	v, err := driver.Version()
+	if err == nil && v != "" && v != "*.*" {
+		paths, locateErr := driver.Libraries().Locate("libnvidia-ml.so." + v)
+		if locateErr == nil && len(paths) > 0 {
+			opts = append(opts, nvml.WithLibraryPath(paths[0]))
+		}
+	}
+
+	lib := nvml.New(opts...)
+	if ret := lib.Init(); ret != nvml.SUCCESS {
+		return "", fmt.Errorf("nvml.Init: %s", lib.ErrorString(ret))
+	}
+	defer func() {
+		_ = lib.Shutdown()
+	}()
+
+	device, ret := lib.DeviceGetHandleByIndex(0)
+	if ret != nvml.SUCCESS {
+		return "", fmt.Errorf("nvml.DeviceGetHandleByIndex(0): %s", lib.ErrorString(ret))
+	}
+
+	brandType, ret := lib.DeviceGetBrand(device)
+	if ret != nvml.SUCCESS {
+		return "", fmt.Errorf("nvml.DeviceGetBrand: %s", lib.ErrorString(ret))
+	}
+	brand, ok := brandTypeToRequirementString(brandType)
+	if !ok {
+		return "", fmt.Errorf("unknown NVML brand type %v", brandType)
+	}
+	return brand, nil
+}
+
+// brandTypeToRequirementString maps NVML brand enums to lowercase tokens
+// consistent with typical NVIDIA_REQUIRE_* image constraints.
+func brandTypeToRequirementString(b nvml.BrandType) (string, bool) {
+	switch b {
+	case nvml.BRAND_UNKNOWN:
+		return "", false
+	case nvml.BRAND_QUADRO:
+		return "quadro", true
+	case nvml.BRAND_TESLA:
+		return "tesla", true
+	case nvml.BRAND_NVS:
+		return "nvs", true
+	case nvml.BRAND_GRID:
+		return "grid", true
+	case nvml.BRAND_GEFORCE:
+		return "geforce", true
+	case nvml.BRAND_TITAN:
+		return "titan", true
+	case nvml.BRAND_NVIDIA_VAPPS:
+		return "nvidiavapps", true
+	case nvml.BRAND_NVIDIA_VPC:
+		return "nvidiavpc", true
+	case nvml.BRAND_NVIDIA_VCS:
+		return "nvidiavcs", true
+	case nvml.BRAND_NVIDIA_VWS:
+		return "nvidiavws", true
+	case nvml.BRAND_NVIDIA_CLOUD_GAMING:
+		return "nvidiacloudgaming", true
+	case nvml.BRAND_QUADRO_RTX:
+		return "quadrortx", true
+	case nvml.BRAND_NVIDIA_RTX:
+		return "nvidiartx", true
+	case nvml.BRAND_NVIDIA:
+		return "nvidia", true
+	case nvml.BRAND_GEFORCE_RTX:
+		return "geforcertx", true
+	case nvml.BRAND_TITAN_RTX:
+		return "titanrtx", true
+	default:
+		return "", false
+	}
+}
From f90d1ff6edbc1bf6f8c2bfdc73872bf6ec96fbad Mon Sep 17 00:00:00 2001
From: Arjun
Date: Mon, 18 May 2026 22:03:14 +0000
Subject: [PATCH 2/2] Switched to hook format

---
 .../check-requirements/check-requirements.go  | 120 +++++++++++
 cmd/nvidia-cdi-hook/commands/commands.go      |   2 +
 internal/discover/hooks.go                    |   3 +
 internal/discover/requirements.go             |  43 ++++
 internal/modifier/cdi.go                      |   3 +-
 internal/modifier/csv.go                      |   3 +-
 internal/requirements/image.go                | 198 ++++++++++++++++++
 pkg/nvcdi/api.go                              |   2 +
 pkg/nvcdi/lib.go                              |   6 +
 pkg/nvcdi/wrapper.go                          |  15 ++
 10 files changed, 393 insertions(+), 2 deletions(-)
 create mode 100644 cmd/nvidia-cdi-hook/check-requirements/check-requirements.go
 create mode 100644 internal/discover/requirements.go
 create mode 100644 internal/requirements/image.go

diff --git a/cmd/nvidia-cdi-hook/check-requirements/check-requirements.go b/cmd/nvidia-cdi-hook/check-requirements/check-requirements.go
new file mode 100644
index 
000000000..d53d86365
--- /dev/null
+++ b/cmd/nvidia-cdi-hook/check-requirements/check-requirements.go
@@ -0,0 +1,129 @@
+/**
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package checkrequirements
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+
+	"github.com/opencontainers/runtime-spec/specs-go"
+	"github.com/urfave/cli/v3"
+
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
+)
+
+// command implements the check-requirements hook.
+type command struct {
+	logger logger.Interface
+}
+
+// options holds the flag values of a single invocation.
+type options struct {
+	containerSpec string
+	driverRoot    string
+}
+
+// NewCommand constructs a check-requirements command with the specified logger.
+func NewCommand(logger logger.Interface) *cli.Command {
+	c := command{
+		logger: logger,
+	}
+	return c.build()
+}
+
+// build constructs the cli.Command and binds its flags to a fresh options value.
+func (m command) build() *cli.Command {
+	cfg := options{}
+
+	return &cli.Command{
+		Name:  "check-requirements",
+		Usage: "Check NVIDIA_REQUIRE_* constraints from the container image",
+		Action: func(_ context.Context, _ *cli.Command) error {
+			return m.run(&cfg)
+		},
+		Flags: []cli.Flag{
+			&cli.StringFlag{
+				Name:        "driver-root",
+				Usage:       "Specify the NVIDIA GPU driver root to use when detecting host properties",
+				Destination: &cfg.driverRoot,
+			},
+			&cli.StringFlag{
+				Name:        "container-spec",
+				Hidden:      true,
+				Category:    "testing-only",
+				Usage:       "Specify the path to the OCI container state. If empty or '-' the state will be read from STDIN",
+				Destination: &cfg.containerSpec,
+			},
+		},
+	}
+}
+
+// run loads the CUDA image from the container state and checks its
+// requirements against the driver root selected by cfg. A non-nil error is
+// returned if the image cannot be loaded or a requirement is not met.
+func (m command) run(cfg *options) error {
+	cudaImage, err := loadCUDAImageFromState(cfg.containerSpec, m.logger)
+	if err != nil {
+		return fmt.Errorf("failed to load CUDA image from container state: %w", err)
+	}
+
+	driver := root.New(
+		root.WithLogger(m.logger),
+		root.WithDriverRoot(cfg.driverRoot),
+	)
+	if err := requirements.CheckImage(m.logger, cudaImage, driver); err != nil {
+		return fmt.Errorf("requirements not met: %w", err)
+	}
+	return nil
+}
+
+// loadCUDAImageFromState reads the container state from containerStatePath,
+// decodes the OCI spec found in the state's bundle, and constructs the CUDA
+// image that the spec describes.
+func loadCUDAImageFromState(containerStatePath string, logger logger.Interface) (*image.CUDA, error) {
+	state, err := oci.LoadContainerState(containerStatePath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load container state: %w", err)
+	}
+
+	specFilePath := oci.GetSpecFilePath(state.Bundle)
+	specFile, err := os.Open(specFilePath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open OCI spec file: %w", err)
+	}
+	defer specFile.Close()
+
+	var spec specs.Spec
+	if err := json.NewDecoder(specFile).Decode(&spec); err != nil {
+		return nil, fmt.Errorf("failed to decode OCI spec: %w", err)
+	}
+
+	cudaImage, err := image.NewCUDAImageFromSpec(
+		&spec,
+		image.WithLogger(logger),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create CUDA image from OCI spec: %w", err)
+	}
+	return &cudaImage, nil
+}
diff --git a/cmd/nvidia-cdi-hook/commands/commands.go b/cmd/nvidia-cdi-hook/commands/commands.go
index a9f6c05ce..72cda51bf 100644
--- a/cmd/nvidia-cdi-hook/commands/commands.go
+++ b/cmd/nvidia-cdi-hook/commands/commands.go
@@ -23,6 +23,7 @@ import (
 	"github.com/urfave/cli/v3"
 
+	checkrequirements "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/check-requirements"
 	"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
 	symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
 	"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
 	disabledevicenodemodification "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/disable-device-node-modification"
@@ -85,6 +86,7 @@ func ConfigureCDIHookCommand(logger logger.Interface, base *cli.Command) *cli.Co
 
 	// Define the supported hooks.
 	base.Commands = []*cli.Command{
+		checkrequirements.NewCommand(logger),
 		ldcache.NewCommand(logger),
 		symlinks.NewCommand(logger),
 		chmod.NewCommand(logger),
diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go
index fb16f8587..629e2927b 100644
--- a/internal/discover/hooks.go
+++ b/internal/discover/hooks.go
@@ -34,6 +34,9 @@ const (
 	//
 	// Deprecated: The chmod hook is deprecated and will be removed in a future release.
 	ChmodHook = HookName("chmod")
+	// A CheckRequirementsHook is used to enforce NVIDIA_REQUIRE_* constraints
+	// from the container image.
+	CheckRequirementsHook = HookName("check-requirements")
 	// A CreateSymlinksHook is used to create symlinks in the container.
CreateSymlinksHook = HookName("create-symlinks") // DisableDeviceNodeModificationHook refers to the hook used to ensure that diff --git a/internal/discover/requirements.go b/internal/discover/requirements.go new file mode 100644 index 000000000..e672aa9b6 --- /dev/null +++ b/internal/discover/requirements.go @@ -0,0 +1,43 @@ +/** +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package discover + +// CheckRequirementsHookOptions defines the options that can be specified when +// creating the check-requirements hook. +type CheckRequirementsHookOptions struct { + DriverRoot string +} + +// NewCheckRequirementsHookDiscoverer creates a discoverer for a +// check-requirements hook. +func NewCheckRequirementsHookDiscoverer(hookCreator HookCreator, o *CheckRequirementsHookOptions) Discover { + hook := hookCreator.Create(CheckRequirementsHook, o.args()...) 
+ if hook == nil { + return None{} + } + return hook +} + +func (o *CheckRequirementsHookOptions) args() []string { + if o == nil { + return nil + } + if o.DriverRoot == "" || o.DriverRoot == "/" { + return nil + } + return []string{"--driver-root=" + o.DriverRoot} +} diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 260005925..b444c614b 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -26,6 +26,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" ) @@ -55,7 +56,7 @@ func (f *Factory) newCDIModifier(isJitCDI bool) (oci.SpecModifier, error) { // Run before the empty-device return so NVIDIA_REQUIRE_* is still enforced when // len(devices)==0 (e.g. CRI CDI injection without matching spec signals). When // there are no requirements, checkRequirements returns immediately. - if err := checkRequirements(f.logger, f.image, f.driver); err != nil { + if err := requirements.CheckImage(f.logger, f.image, f.driver); err != nil { return nil, fmt.Errorf("requirements not met: %w", err) } diff --git a/internal/modifier/csv.go b/internal/modifier/csv.go index 2b93d6041..3b7a51dcb 100644 --- a/internal/modifier/csv.go +++ b/internal/modifier/csv.go @@ -21,6 +21,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/NVIDIA/nvidia-container-toolkit/internal/requirements" ) // newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper. 
@@ -33,7 +34,7 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
 	}
 	f.logger.Infof("Constructing modifier from config: %+v", *f.cfg)
 
-	if err := checkRequirements(f.logger, f.image, f.driver); err != nil {
+	if err := requirements.CheckImage(f.logger, f.image, f.driver); err != nil {
 		return nil, fmt.Errorf("requirements not met: %v", err)
 	}
 
diff --git a/internal/requirements/image.go b/internal/requirements/image.go
new file mode 100644
index 000000000..61489b178
--- /dev/null
+++ b/internal/requirements/image.go
@@ -0,0 +1,210 @@
+/**
+# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package requirements
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"github.com/NVIDIA/go-nvml/pkg/nvml"
+	"golang.org/x/mod/semver"
+
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
+)
+
+// CheckImage evaluates the NVIDIA_REQUIRE_* constraints of cudaImage against
+// properties detected on the host: the CUDA version, the compute capability
+// of CUDA device 0, the NVIDIA driver version, and the brand of the GPU at
+// NVML index 0. A property that cannot be determined is logged as a warning
+// and not registered; evaluation of constraints is left to Assert.
+//
+// The check is skipped (nil is returned) when cudaImage is nil, when the image
+// disables requirement checks, or when it declares no requirements. A nil
+// driver falls back to a driver root constructed with default options.
+func CheckImage(logger logger.Interface, cudaImage *image.CUDA, driver *root.Driver) error {
+	if cudaImage == nil {
+		logger.Debugf("No CUDA image specified; skipping requirement checks")
+		return nil
+	}
+	if cudaImage.HasDisableRequire() {
+		logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", true)
+		return nil
+	}
+
+	imageRequirements, err := cudaImage.GetRequirements()
+	if err != nil {
+		return fmt.Errorf("failed to get image requirements: %w", err)
+	}
+	if len(imageRequirements) == 0 {
+		return nil
+	}
+
+	if driver == nil {
+		driver = root.New(root.WithLogger(logger))
+	}
+
+	r := New(logger, imageRequirements)
+
+	cudaVersion, err := cuda.Version()
+	if err != nil {
+		logger.Warningf("Failed to get CUDA version: %v", err)
+	} else {
+		r.AddVersionProperty(CUDA, cudaVersion)
+	}
+
+	computeCapability, err := cuda.ComputeCapability(0)
+	if err != nil {
+		logger.Warningf("Failed to get CUDA Compute Capability: %v", err)
+	} else {
+		r.AddVersionProperty(ARCH, computeCapability)
+	}
+
+	driverVersion, err := driver.Version()
+	if err != nil {
+		logger.Warningf("Failed to get NVIDIA driver version: %v", err)
+	} else {
+		normalized, normErr := normalizeDriverVersionForSemver(driverVersion)
+		if normErr != nil {
+			logger.Warningf("NVIDIA driver version %q is not semver-normalizable: %v", driverVersion, normErr)
+		} else {
+			r.AddVersionProperty(DRIVER, normalized)
+		}
+	}
+
+	brand, err := getBrandFromNVML(driver)
+	if err != nil {
+		logger.Warningf("Failed to get GPU brand from NVML: %v", err)
+	} else {
+		r.AddStringProperty(BRAND, brand)
+	}
+
+	return r.Assert()
+}
+
+// normalizeDriverVersionForSemver converts a driver version taken from a
+// libcuda / libnvidia-ml soname suffix into a form accepted by
+// golang.org/x/mod/semver (no leading zeros in numeric segments).
+func normalizeDriverVersionForSemver(raw string) (string, error) {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return "", fmt.Errorf("empty driver version")
+	}
+	parts := strings.Split(raw, ".")
+	out := make([]string, 0, len(parts))
+	for _, p := range parts {
+		if p == "" {
+			return "", fmt.Errorf("empty version segment in %q", raw)
+		}
+		if strings.TrimLeft(p, "0123456789") != "" {
+			return "", fmt.Errorf("non-numeric version segment %q in %q", p, raw)
+		}
+		n, err := strconv.ParseUint(p, 10, 64)
+		if err != nil {
+			return "", fmt.Errorf("invalid version segment %q in %q: %w", p, raw, err)
+		}
+		out = append(out, strconv.FormatUint(n, 10))
+	}
+	normalized := strings.Join(out, ".")
+	if !semver.IsValid("v" + normalized) {
+		return "", fmt.Errorf("normalized driver version %q is not valid semver", normalized)
+	}
+	return normalized, nil
+}
+
+// getBrandFromNVML returns a lowercase brand token for the GPU at NVML index
+// 0. driver is dereferenced unconditionally and is expected to be non-nil; its
+// version is used to locate a versioned libnvidia-ml under the driver root.
+func getBrandFromNVML(driver *root.Driver) (string, error) {
+	var opts []nvml.LibraryOption
+	v, err := driver.Version()
+	if err == nil && v != "" && v != "*.*" {
+		paths, locateErr := driver.Libraries().Locate("libnvidia-ml.so." + v)
+		if locateErr == nil && len(paths) > 0 {
+			opts = append(opts, nvml.WithLibraryPath(paths[0]))
+		}
+	}
+
+	lib := nvml.New(opts...)
+	if ret := lib.Init(); ret != nvml.SUCCESS {
+		return "", fmt.Errorf("nvml.Init: %s", lib.ErrorString(ret))
+	}
+	defer func() {
+		_ = lib.Shutdown()
+	}()
+
+	device, ret := lib.DeviceGetHandleByIndex(0)
+	if ret != nvml.SUCCESS {
+		return "", fmt.Errorf("nvml.DeviceGetHandleByIndex(0): %s", lib.ErrorString(ret))
+	}
+
+	brandType, ret := lib.DeviceGetBrand(device)
+	if ret != nvml.SUCCESS {
+		return "", fmt.Errorf("nvml.DeviceGetBrand: %s", lib.ErrorString(ret))
+	}
+	brand, ok := brandTypeToRequirementString(brandType)
+	if !ok {
+		return "", fmt.Errorf("unknown NVML brand type %v", brandType)
+	}
+	return brand, nil
+}
+
+// brandTypeToRequirementString maps NVML brand enums to lowercase tokens
+// consistent with typical NVIDIA_REQUIRE_* image constraints.
+func brandTypeToRequirementString(b nvml.BrandType) (string, bool) {
+	switch b {
+	case nvml.BRAND_UNKNOWN:
+		return "", false
+	case nvml.BRAND_QUADRO:
+		return "quadro", true
+	case nvml.BRAND_TESLA:
+		return "tesla", true
+	case nvml.BRAND_NVS:
+		return "nvs", true
+	case nvml.BRAND_GRID:
+		return "grid", true
+	case nvml.BRAND_GEFORCE:
+		return "geforce", true
+	case nvml.BRAND_TITAN:
+		return "titan", true
+	case nvml.BRAND_NVIDIA_VAPPS:
+		return "nvidiavapps", true
+	case nvml.BRAND_NVIDIA_VPC:
+		return "nvidiavpc", true
+	case nvml.BRAND_NVIDIA_VCS:
+		return "nvidiavcs", true
+	case nvml.BRAND_NVIDIA_VWS:
+		return "nvidiavws", true
+	case nvml.BRAND_NVIDIA_CLOUD_GAMING:
+		return "nvidiacloudgaming", true
+	case nvml.BRAND_QUADRO_RTX:
+		return "quadrortx", true
+	case nvml.BRAND_NVIDIA_RTX:
+		return "nvidiartx", true
+	case nvml.BRAND_NVIDIA:
+		return "nvidia", true
+	case nvml.BRAND_GEFORCE_RTX:
+		return "geforcertx", true
+	case nvml.BRAND_TITAN_RTX:
+		return "titanrtx", true
+	default:
+		return "", false
+	}
+}
diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go
index 651bf2f74..99cd7cb14 100644
--- a/pkg/nvcdi/api.go
+++ b/pkg/nvcdi/api.go
@@ -51,6 +51,8 @@ const (
 	// AllHooks is a special hook 
name that allows all hooks to be matched. AllHooks = discover.AllHooks + // A CheckRequirementsHook is used to enforce NVIDIA_REQUIRE_* constraints. + CheckRequirementsHook = discover.CheckRequirementsHook // A CreateSymlinksHook is used to create symlinks in the container. CreateSymlinksHook = discover.CreateSymlinksHook // DisableDeviceNodeModificationHook refers to the hook used to ensure that diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index fa84e19c5..ae6ed7c7c 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -100,6 +100,12 @@ func New(opts ...Option) (Interface, error) { vendor: o.getVendorOrDefault(), class: o.getClassOrDefault(), mergedDeviceOptions: o.mergedDeviceOptions, + editsFactory: l.editsFactory, + additionalCommonEdits: []discover.Discover{ + discover.NewCheckRequirementsHookDiscoverer(l.hookCreator, &discover.CheckRequirementsHookOptions{ + DriverRoot: l.driver.Root, + }), + }, } return &w, nil } diff --git a/pkg/nvcdi/wrapper.go b/pkg/nvcdi/wrapper.go index 9d4520798..4d211ff44 100644 --- a/pkg/nvcdi/wrapper.go +++ b/pkg/nvcdi/wrapper.go @@ -23,6 +23,8 @@ import ( "tags.cncf.io/container-device-interface/specs-go" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" ) @@ -34,6 +36,9 @@ type wrapper struct { class string mergedDeviceOptions []transform.MergedDeviceOption + + editsFactory edits.Factory + additionalCommonEdits []discover.Discover } // TODO: Rename this type @@ -96,6 +101,16 @@ func (m *wrapper) GetCommonEdits() (*cdi.ContainerEdits, error) { if err != nil { return nil, err } + for _, discoverer := range m.additionalCommonEdits { + if discoverer == nil { + continue + } + additionalEdits, err := m.editsFactory.FromDiscoverer(discoverer) + if err != nil { + 
return nil, err + } + edits.Append(additionalEdits) + } edits.Env = append(edits.Env, image.EnvVarNvidiaVisibleDevices+"=void") return edits, nil