diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index 21c9c14db..7feea2cb0 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -121,7 +121,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver if err != nil { return nil, fmt.Errorf("failed to get driver version: %w", err) } - cudaLibRoot, err := driver.GetDriverLibDirectory() + cudaLibRoots, err := driver.GetDriverLibDirectories() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err) } @@ -152,7 +152,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver lookup.NewFileLocator( lookup.WithLogger(logger), lookup.WithRoot(driver.Root), - lookup.WithSearchPaths(buildXOrgSearchPaths(cudaLibRoot)...), + lookup.WithSearchPaths(buildXOrgSearchPaths(cudaLibRoots...)...), lookup.WithCount(1), ), driver.Root, @@ -239,8 +239,16 @@ func (d graphicsDriverLibraries) isDriverLibrary(filename string, libraryName st return match } +func buildXOrgSearchPaths(roots ...string) []string { + var paths []string + for _, root := range roots { + paths = append(paths, buildXOrgSearchPathsForSingle(root)...) + } + return paths +} + // buildXOrgSearchPaths returns the ordered list of search paths for XOrg files. -func buildXOrgSearchPaths(libRoot string) []string { +func buildXOrgSearchPathsForSingle(libRoot string) []string { var paths []string if libRoot != "" { paths = append(paths, diff --git a/internal/lookup/root/cuda_test.go b/internal/lookup/root/cuda_test.go index 4f548238f..325c1806c 100644 --- a/internal/lookup/root/cuda_test.go +++ b/internal/lookup/root/cuda_test.go @@ -68,12 +68,14 @@ func TestLocate(t *testing.T) { WithDriverRoot(driverRoot), ) - driverLibraryPath, err := l.GetDriverLibDirectory() - require.ErrorIs(t, err, tc.expectedError) + driverLibraryPath, err := l.GetDriverLibDirectories() + if tc.expectedError != nil { + require.ErrorIs(t, err, tc.expectedError) + return + } // NOTE: We need to strip `/private` on MacOs due to symlink resolution - stripped := strings.TrimPrefix(driverLibraryPath, "/private") - + stripped := strings.TrimPrefix(driverLibraryPath[0], "/private") require.Equal(t, tc.expected, stripped) }) } diff --git a/internal/lookup/root/root.go b/internal/lookup/root/root.go index 3b21b24ed..d09522774 100644 --- a/internal/lookup/root/root.go +++ b/internal/lookup/root/root.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "path/filepath" + "slices" "strings" "sync" @@ -43,8 +44,8 @@ type Driver struct { // version caches the driver version. version string - // driverLibDirectory caches the path to parent of the driver libraries - driverLibDirectory string + // driverLibDirectories caches the paths to parent of the driver libraries + driverLibDirectories []string } // New creates a new Driver root using the specified options. @@ -70,13 +71,13 @@ func New(opts ...Option) *Driver { } d := &Driver{ - logger: o.logger, - Root: o.Root, - DevRoot: o.DevRoot, - librarySearchPaths: o.librarySearchPaths, - configSearchPaths: o.configSearchPaths, - version: driverVersion, - driverLibDirectory: "", + logger: o.logger, + Root: o.Root, + DevRoot: o.DevRoot, + librarySearchPaths: o.librarySearchPaths, + configSearchPaths: o.configSearchPaths, + version: driverVersion, + driverLibDirectories: nil, } return d @@ -97,34 +98,36 @@ func (r *Driver) Version() (string, error) { return r.version, nil } -// GetDriverLibDirectory returns the cached directory where the driver libs are -// found if possible. +// GetDriverLibDirectories returns the cached directories where the driver libs +// are found if possible. // If this has not yet been initialized, the path is first detected and then returned. -func (r *Driver) GetDriverLibDirectory() (string, error) { +func (r *Driver) GetDriverLibDirectories() ([]string, error) { r.Lock() defer r.Unlock() - if r.driverLibDirectory == "" { + if len(r.driverLibDirectories) == 0 { if err := r.updateInfo(); err != nil { - return "", err + return nil, err } } - return r.driverLibDirectory, nil + return r.driverLibDirectories, nil } func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, error) { - libcudasoParentDirPath, err := r.GetDriverLibDirectory() + libcudasoParentDirPaths, err := r.GetDriverLibDirectories() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err) } - searchPaths := []string{libcudasoParentDirPath} + searchPaths := slices.Clone(libcudasoParentDirPaths) for _, dir := range additionalDirs { if strings.HasPrefix(dir, "/") { searchPaths = append(searchPaths, dir) } else { - searchPaths = append(searchPaths, filepath.Join(libcudasoParentDirPath, dir)) + for _, libcudasoParentDirPath := range libcudasoParentDirPaths { + searchPaths = append(searchPaths, filepath.Join(libcudasoParentDirPath, dir)) + } } } @@ -141,7 +144,7 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, } func (r *Driver) updateInfo() error { - driverLibPath, version, err := r.inferVersion() + _, version, err := r.inferVersion() if err != nil { return err } @@ -149,8 +152,25 @@ func (r *Driver) updateInfo() error { return fmt.Errorf("unexpected version detected: %v != %v", r.version, version) } + versionedDriverLibPaths, err := r.Libraries().Locate("lib*.so." + version) + if err != nil { + return fmt.Errorf("failed to locate versioned driver libraries: %w", err) + } + + var uniqueDirs []string + seen := make(map[string]bool) + + for _, path := range versionedDriverLibPaths { + dir := filepath.Dir(path) + if seen[dir] { + continue + } + seen[dir] = true + uniqueDirs = append(uniqueDirs, r.RelativeToRoot(dir)) + } + r.version = version - r.driverLibDirectory = r.RelativeToRoot(filepath.Dir(driverLibPath)) + r.driverLibDirectories = uniqueDirs return nil } @@ -167,7 +187,7 @@ func (r *Driver) inferVersion() (string, string, error) { for _, driverLib := range []string{"libcuda.so.", "libnvidia-ml.so."} { driverLibPaths, err := r.Libraries().Locate(driverLib + versionSuffix) if err != nil { - errs = errors.Join(errs, fmt.Errorf("failed to locate libcuda.so: %w", err)) + errs = errors.Join(errs, fmt.Errorf("failed to locate %v: %w", driverLib, err)) continue } driverLibPath := driverLibPaths[0] diff --git a/pkg/nvcdi/driver-nvml.go b/pkg/nvcdi/driver-nvml.go index f672f1820..75efbfe7e 100644 --- a/pkg/nvcdi/driver-nvml.go +++ b/pkg/nvcdi/driver-nvml.go @@ -111,13 +111,13 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover disableDeviceNodeModification := l.hookCreator.Create(DisableDeviceNodeModificationHook) discoverers = append(discoverers, disableDeviceNodeModification) - driverLibDirectory, err := l.driver.GetDriverLibDirectory() + driverLibDirectories, err := l.driver.GetDriverLibDirectories() if err != nil { return nil, fmt.Errorf("failed to get libcuda.so parent directory path: %w", err) } environmentVariable := &discover.EnvVar{ Name: "NVIDIA_CTK_LIBCUDA_DIR", - Value: driverLibDirectory, + Value: strings.Join(driverLibDirectories, string(filepath.ListSeparator)), } discoverers = append(discoverers, environmentVariable)