@@ -31,6 +31,7 @@ import (
3131 "github.com/NVIDIA/go-nvlib/pkg/nvmdev"
3232 "github.com/NVIDIA/go-nvlib/pkg/nvpci"
3333 devchar "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/system/create-dev-char-symlinks"
34+ "github.com/moby/sys/symlink"
3435 log "github.com/sirupsen/logrus"
3536 "github.com/stretchr/testify/assert/yaml"
3637 cli "github.com/urfave/cli/v3"
@@ -135,6 +136,7 @@ var (
135136 hostRootFlag string
136137 driverInstallDirFlag string
137138 driverInstallDirCtrPathFlag string
139+ hostRootCtrPath = "/host"
138140)
139141
140142// defaultGPUWorkloadConfig is "vm-passthrough" unless
@@ -742,22 +744,43 @@ func isDriverManagedByOperator(ctx context.Context) (bool, error) {
742744 return false , nil
743745}
744746
747+ func resolveHostNvidiaSMIPath () (string , error ) {
748+ nvidiaSMIPath , err := symlink .FollowSymlinkInScope (filepath .Join (hostRootCtrPath , "usr/bin/nvidia-smi" ), hostRootCtrPath )
749+ if err != nil {
750+ return "" , fmt .Errorf ("failed to resolve 'nvidia-smi' path on the host: %w" , err )
751+ }
752+
753+ fileInfo , err := os .Lstat (nvidiaSMIPath )
754+ if err != nil {
755+ return "" , fmt .Errorf ("no 'nvidia-smi' file present on the host: %w" , err )
756+ }
757+ if fileInfo .Size () == 0 {
758+ return "" , fmt .Errorf ("empty 'nvidia-smi' file found on the host" )
759+ }
760+
761+ return nvidiaSMIPath , nil
762+ }
763+
764+ func hostDriverValidationCommand () (string , []string , error ) {
765+ if _ , err := resolveHostNvidiaSMIPath (); err != nil {
766+ return "" , nil , err
767+ }
768+
769+ return "chroot" , []string {hostRootCtrPath , "nvidia-smi" }, nil
770+ }
771+
745772func validateHostDriver (silent bool ) error {
746773 log .Info ("Attempting to validate a pre-installed driver on the host" )
747- if fileInfo , err := os .Lstat (filepath .Join ("/host" , wslNvidiaSMIPath )); err == nil && fileInfo .Size () != 0 {
774+ if fileInfo , err := os .Lstat (filepath .Join (hostRootCtrPath , wslNvidiaSMIPath )); err == nil && fileInfo .Size () != 0 {
748775 log .Infof ("WSL2 system detected, assuming driver is pre-installed" )
749776 disableDevCharSymlinkCreation = true
750777 return nil
751778 }
752- fileInfo , err := os .Lstat ("/host/usr/bin/nvidia-smi" )
779+
780+ command , args , err := hostDriverValidationCommand ()
753781 if err != nil {
754- return fmt .Errorf ("no 'nvidia-smi' file present on the host: %w" , err )
755- }
756- if fileInfo .Size () == 0 {
757- return fmt .Errorf ("empty 'nvidia-smi' file found on the host" )
782+ return err
758783 }
759- command := "chroot"
760- args := []string {"/host" , "nvidia-smi" }
761784
762785 return runCommand (command , args , silent )
763786}
@@ -1747,17 +1770,21 @@ func (v *VGPUManager) validate() error {
17471770 return nil
17481771}
17491772
1750- func (v * VGPUManager ) runValidation (silent bool ) (hostDriver bool , err error ) {
1751- // invoke validation command
1773+ func vGPUManagerValidationCommand () (bool , string , []string ) {
17521774 command := "chroot"
17531775 args := []string {"/run/nvidia/driver" , "nvidia-smi" }
17541776
1755- // check if driver is pre-installed on the host and use host path for validation
1756- if _ , err := os .Lstat ("/host/usr/bin/nvidia-smi" ); err == nil {
1757- args = []string {"/host" , "nvidia-smi" }
1758- hostDriver = true
1777+ if _ , err := resolveHostNvidiaSMIPath (); err == nil {
1778+ return true , command , []string {hostRootCtrPath , "nvidia-smi" }
17591779 }
17601780
1781+ return false , command , args
1782+ }
1783+
1784+ func (v * VGPUManager ) runValidation (silent bool ) (hostDriver bool , err error ) {
1785+ // invoke validation command
1786+ hostDriver , command , args := vGPUManagerValidationCommand ()
1787+
17611788 if withWaitFlag {
17621789 return hostDriver , runCommandWithWait (command , args , sleepIntervalSecondsFlag , silent )
17631790 }
0 commit comments