diff --git a/cosmosutils/binary.go b/cosmosutils/binary.go index 56e41a5..76d0cc1 100644 --- a/cosmosutils/binary.go +++ b/cosmosutils/binary.go @@ -260,7 +260,7 @@ func getMinitiadBinaryURL(vm, version string) (string, error) { } // FindBinaryDir walks versionDir to find the directory that contains the named -// executable. This avoids hardcoding assumptions about how a release tarball is +// binary. This avoids hardcoding assumptions about how a release tarball is // structured, so the code stays correct even if a future tarball places the // binary inside a subdirectory. func FindBinaryDir(versionDir, binaryName string) (string, error) { @@ -269,7 +269,7 @@ func FindBinaryDir(versionDir, binaryName string) (string, error) { if err != nil { return err } - if !info.IsDir() && info.Name() == binaryName && info.Mode()&0o111 != 0 { + if !info.IsDir() && info.Name() == binaryName { result = filepath.Dir(path) return filepath.SkipAll } diff --git a/cosmosutils/binary_test.go b/cosmosutils/binary_test.go index 23b7fe0..e50708c 100644 --- a/cosmosutils/binary_test.go +++ b/cosmosutils/binary_test.go @@ -341,13 +341,17 @@ func TestFindBinaryDir(t *testing.T) { wantRel: filepath.Join("a", "b", "c"), }, { - name: "non-executable file is ignored", + name: "finds binary before executable permissions are restored", layout: func(root string) { - os.MkdirAll(root, 0o755) - os.WriteFile(filepath.Join(root, "minitiad"), []byte("data"), 0o644) + if err := os.MkdirAll(root, 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "minitiad"), []byte("data"), 0o644); err != nil { + t.Fatal(err) + } }, binaryName: "minitiad", - wantRel: "", + wantRel: ".", }, { name: "wrong name is ignored", diff --git a/io/filesystem.go b/io/filesystem.go index a5e8798..742e18a 100644 --- a/io/filesystem.go +++ b/io/filesystem.go @@ -38,6 +38,11 @@ func DownloadAndExtractTarGz(url, tarballPath, extractedPath string) error { } func ExtractTarGz(src string, dest string) error { + destRoot, err := filepath.Abs(dest) + if err != nil { + return err + } + file, err := os.Open(src) if err != nil { return err @@ -60,23 +65,27 @@ func ExtractTarGz(src string, dest string) error { return err } - target := filepath.Join(dest, header.Name) + target, err := safeArchivePath(destRoot, header.Name) + if err != nil { + return err + } switch header.Typeflag { case tar.TypeDir: if err := os.MkdirAll(target, os.ModePerm); err != nil { return err } case tar.TypeReg: - file, err := os.Create(target) - if err != nil { + if err := os.MkdirAll(filepath.Dir(target), os.ModePerm); err != nil { return err } - _, err = io.Copy(file, tarReader) + file, err := os.OpenFile(target, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(header.Mode)) if err != nil { return err } - err = file.Close() - if err != nil { + if err := writeTarFile(file, tarReader); err != nil { + return err + } + if err := os.Chmod(target, os.FileMode(header.Mode)); err != nil { return err } default: @@ -86,6 +95,34 @@ func ExtractTarGz(src string, dest string) error { return nil } +func writeTarFile(file *os.File, src io.Reader) (err error) { + defer func() { + if closeErr := file.Close(); err == nil && closeErr != nil { + err = closeErr + } + }() + + _, err = io.Copy(file, src) + return err +} + +func safeArchivePath(destRoot, entryName string) (string, error) { + cleanName := filepath.Clean(entryName) + if cleanName == "." { + return destRoot, nil + } + + target := filepath.Join(destRoot, cleanName) + rel, err := filepath.Rel(destRoot, target) + if err != nil { + return "", err + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("unsafe archive entry path: %s", entryName) + } + return target, nil +} + func SetLibraryPaths(binaryDir string) error { envKey, envValue, err := LibraryPathEnv(binaryDir) if err != nil { diff --git a/io/filesystem_test.go b/io/filesystem_test.go index ba00b17..1331fa5 100644 --- a/io/filesystem_test.go +++ b/io/filesystem_test.go @@ -1,7 +1,10 @@ package io import ( + "archive/tar" + "compress/gzip" "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -59,6 +62,69 @@ func TestExtractTarGz(t *testing.T) { err := ExtractTarGz("./invalid.tar.gz", "./invalid") assert.Error(t, err) }) + + t.Run("PreservesExtractedFileMode", func(t *testing.T) { + tmpDir := t.TempDir() + tarballPath := filepath.Join(tmpDir, "test.tar.gz") + extractDir := filepath.Join(tmpDir, "extract") + + file, err := os.Create(tarballPath) + assert.NoError(t, err) + + gzw := gzip.NewWriter(file) + tw := tar.NewWriter(gzw) + + content := []byte("#!/bin/sh\necho ok\n") + header := &tar.Header{ + Name: "minitiad", + Mode: 0o755, + Size: int64(len(content)), + Typeflag: tar.TypeReg, + } + assert.NoError(t, tw.WriteHeader(header)) + _, err = tw.Write(content) + assert.NoError(t, err) + assert.NoError(t, tw.Close()) + assert.NoError(t, gzw.Close()) + assert.NoError(t, file.Close()) + + err = ExtractTarGz(tarballPath, extractDir) + assert.NoError(t, err) + + info, err := os.Stat(filepath.Join(extractDir, "minitiad")) + assert.NoError(t, err) + assert.Equal(t, os.FileMode(0o755), info.Mode().Perm()) + }) + + t.Run("RejectsPathTraversalEntries", func(t *testing.T) { + tmpDir := t.TempDir() + tarballPath := filepath.Join(tmpDir, "test.tar.gz") + extractDir := filepath.Join(tmpDir, "extract") + + file, err := os.Create(tarballPath) + assert.NoError(t, err) + + gzw := gzip.NewWriter(file) + tw := tar.NewWriter(gzw) + + content := []byte("bad\n") + header := &tar.Header{ + Name: "../escape", + Mode: 0o644, + Size: int64(len(content)), + Typeflag: tar.TypeReg, + } + assert.NoError(t, tw.WriteHeader(header)) + _, err = tw.Write(content) + assert.NoError(t, err) + assert.NoError(t, tw.Close()) + assert.NoError(t, gzw.Close()) + assert.NoError(t, file.Close()) + + err = ExtractTarGz(tarballPath, extractDir) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsafe archive entry path") + }) } func TestSetLibraryPaths(t *testing.T) {