Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions internal/discover/graphics.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
if err != nil {
return nil, fmt.Errorf("failed to get driver version: %w", err)
}
cudaLibRoot, err := driver.GetDriverLibDirectory()
cudaLibRoots, err := driver.GetDriverLibDirectories()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err)
}
Expand Down Expand Up @@ -152,7 +152,7 @@ func newGraphicsLibrariesDiscoverer(logger logger.Interface, driver *root.Driver
lookup.NewFileLocator(
lookup.WithLogger(logger),
lookup.WithRoot(driver.Root),
lookup.WithSearchPaths(buildXOrgSearchPaths(cudaLibRoot)...),
lookup.WithSearchPaths(buildXOrgSearchPaths(cudaLibRoots...)...),
lookup.WithCount(1),
),
driver.Root,
Expand Down Expand Up @@ -239,8 +239,16 @@ func (d graphicsDriverLibraries) isDriverLibrary(filename string, libraryName st
return match
}

func buildXOrgSearchPaths(roots ...string) []string {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this comment be moved above this method?

// buildXOrgSearchPaths returns the ordered list of search paths for XOrg files.

var paths []string
for _, root := range roots {
paths = append(paths, buildXOrgSearchPathsForSingle(root)...)
}
return paths
}

// buildXOrgSearchPathsForSingle returns the ordered list of search paths for XOrg files for a single library root.
func buildXOrgSearchPaths(libRoot string) []string {
func buildXOrgSearchPathsForSingle(libRoot string) []string {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the comment so it matches the new method name?

var paths []string
if libRoot != "" {
paths = append(paths,
Expand Down
10 changes: 6 additions & 4 deletions internal/lookup/root/cuda_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ func TestLocate(t *testing.T) {
WithDriverRoot(driverRoot),
)

driverLibraryPath, err := l.GetDriverLibDirectory()
require.ErrorIs(t, err, tc.expectedError)
driverLibraryPath, err := l.GetDriverLibDirectories()
if tc.expectedError != nil {
require.ErrorIs(t, err, tc.expectedError)
return
}

// NOTE: We need to strip `/private` on macOS due to symlink resolution
stripped := strings.TrimPrefix(driverLibraryPath, "/private")

stripped := strings.TrimPrefix(driverLibraryPath[0], "/private")
require.Equal(t, tc.expected, stripped)
})
}
Expand Down
62 changes: 41 additions & 21 deletions internal/lookup/root/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"sync"

Expand All @@ -43,8 +44,8 @@ type Driver struct {

// version caches the driver version.
version string
// driverLibDirectory caches the path to parent of the driver libraries
driverLibDirectory string
// driverLibDirectories caches the paths to the parent directories of the driver libraries
driverLibDirectories []string
}

// New creates a new Driver root using the specified options.
Expand All @@ -70,13 +71,13 @@ func New(opts ...Option) *Driver {
}

d := &Driver{
logger: o.logger,
Root: o.Root,
DevRoot: o.DevRoot,
librarySearchPaths: o.librarySearchPaths,
configSearchPaths: o.configSearchPaths,
version: driverVersion,
driverLibDirectory: "",
logger: o.logger,
Root: o.Root,
DevRoot: o.DevRoot,
librarySearchPaths: o.librarySearchPaths,
configSearchPaths: o.configSearchPaths,
version: driverVersion,
driverLibDirectories: nil,
}

return d
Expand All @@ -97,34 +98,36 @@ func (r *Driver) Version() (string, error) {
return r.version, nil
}

// GetDriverLibDirectory returns the cached directory where the driver libs are
// found if possible.
// GetDriverLibDirectories returns the cached directories where the driver libs
// are found if possible.
// If these have not yet been initialized, the paths are first detected and then returned.
func (r *Driver) GetDriverLibDirectory() (string, error) {
func (r *Driver) GetDriverLibDirectories() ([]string, error) {
r.Lock()
defer r.Unlock()

if r.driverLibDirectory == "" {
if len(r.driverLibDirectories) == 0 {
if err := r.updateInfo(); err != nil {
return "", err
return nil, err
}
}

return r.driverLibDirectory, nil
return r.driverLibDirectories, nil
}

func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator, error) {
libcudasoParentDirPath, err := r.GetDriverLibDirectory()
libcudasoParentDirPaths, err := r.GetDriverLibDirectories()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent directory: %w", err)
}

searchPaths := []string{libcudasoParentDirPath}
searchPaths := slices.Clone(libcudasoParentDirPaths)
for _, dir := range additionalDirs {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already search these paths for libcuda.so.RM_VERSION and now we're specifying them again. I agree that we can make this more robust, but I would rather use a specific set of directories here than broadening the set again.

As a follow-up question: Could we plumb the other directories through additionalDirs instead of hardcoding them here?

if strings.HasPrefix(dir, "/") {
searchPaths = append(searchPaths, dir)
} else {
searchPaths = append(searchPaths, filepath.Join(libcudasoParentDirPath, dir))
for _, libcudasoParentDirPath := range libcudasoParentDirPaths {
searchPaths = append(searchPaths, filepath.Join(libcudasoParentDirPath, dir))
}
}
}

Expand All @@ -141,16 +144,33 @@ func (r *Driver) DriverLibraryLocator(additionalDirs ...string) (lookup.Locator,
}

func (r *Driver) updateInfo() error {
driverLibPath, version, err := r.inferVersion()
_, version, err := r.inferVersion()
if err != nil {
return err
}
if r.version != "" && r.version != version {
return fmt.Errorf("unexpected version detected: %v != %v", r.version, version)
}

versionedDriverLibPaths, err := r.Libraries().Locate("lib*.so." + version)
if err != nil {
return fmt.Errorf("failed to locate versioned driver libraries: %w", err)
}

var uniqueDirs []string
seen := make(map[string]bool)

for _, path := range versionedDriverLibPaths {
dir := filepath.Dir(path)
if seen[dir] {
continue
}
seen[dir] = true
uniqueDirs = append(uniqueDirs, r.RelativeToRoot(dir))
}

r.version = version
r.driverLibDirectory = r.RelativeToRoot(filepath.Dir(driverLibPath))
r.driverLibDirectories = uniqueDirs

return nil
}
Expand All @@ -167,7 +187,7 @@ func (r *Driver) inferVersion() (string, string, error) {
for _, driverLib := range []string{"libcuda.so.", "libnvidia-ml.so."} {
driverLibPaths, err := r.Libraries().Locate(driverLib + versionSuffix)
if err != nil {
errs = errors.Join(errs, fmt.Errorf("failed to locate libcuda.so: %w", err))
errs = errors.Join(errs, fmt.Errorf("failed to locate %v: %w", driverLib, err))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
errs = errors.Join(errs, fmt.Errorf("failed to locate %v: %w", driverLib, err))
errs = errors.Join(errs, fmt.Errorf("failed to locate %q: %w", driverLib, err))

continue
}
driverLibPath := driverLibPaths[0]
Expand Down
4 changes: 2 additions & 2 deletions pkg/nvcdi/driver-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,13 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
disableDeviceNodeModification := l.hookCreator.Create(DisableDeviceNodeModificationHook)
discoverers = append(discoverers, disableDeviceNodeModification)

driverLibDirectory, err := l.driver.GetDriverLibDirectory()
driverLibDirectories, err := l.driver.GetDriverLibDirectories()
if err != nil {
return nil, fmt.Errorf("failed to get libcuda.so parent directory path: %w", err)
}
environmentVariable := &discover.EnvVar{
Name: "NVIDIA_CTK_LIBCUDA_DIR",
Value: driverLibDirectory,
Value: strings.Join(driverLibDirectories, string(filepath.ListSeparator)),
}
discoverers = append(discoverers, environmentVariable)

Expand Down