Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions commands/add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ func TestAddErrors(t *testing.T) {
_, err = mockFs.Create(policy.SystemDefaultPolicyPath)

require.NoError(t, err)
addCmd = MockAddCmd(mockFs)
Comment thread
EthanHeilman marked this conversation as resolved.

policyPath, err = addCmd.Run(principal, userEmail, issuer)
require.ErrorContains(t, err, "file has insecure permissions: expected one of the following permissions [640], got (0)")
require.Empty(t, policyPath)

err = mockFs.Chmod(policy.SystemDefaultPolicyPath, 0640)
require.NoError(t, err)
Expand Down
107 changes: 47 additions & 60 deletions commands/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/fs"
"os/user"
"path/filepath"
"strings"
Expand All @@ -32,12 +31,11 @@ import (

// AuditCmd provides functionality to audit policy files against provider definitions
type AuditCmd struct {
Fs afero.Fs
Out io.Writer
ErrOut io.Writer
filePermsChecker files.PermsChecker
ProviderLoader policy.ProviderLoader
CurrentUsername string
Fs files.FileSystem
Out io.Writer
ErrOut io.Writer
ProviderLoader policy.ProviderLoader
CurrentUsername string

// Args
ProviderPath string // Custom provider file path
Expand All @@ -48,20 +46,16 @@ type AuditCmd struct {

// NewAuditCmd creates a new AuditCmd with default settings
func NewAuditCmd(out io.Writer, errOut io.Writer) *AuditCmd {
fs := afero.NewOsFs()
return &AuditCmd{
Fs: fs,
Fs: files.NewFileSystem(afero.NewOsFs()),
Out: out,
ErrOut: errOut,
ProviderLoader: policy.NewProviderFileLoader(),
CurrentUsername: getCurrentUsername(),
filePermsChecker: files.PermsChecker{
Fs: fs,
CmdRunner: files.ExecCmd,
},

ProviderPath: policy.SystemDefaultProvidersPath,
PolicyPath: policy.SystemDefaultPolicyPath,
ProviderPath: policy.SystemDefaultProvidersPath,
PolicyPath: policy.SystemDefaultPolicyPath,
SkipUserPolicy: false,
}
}

Expand Down Expand Up @@ -90,7 +84,7 @@ func (a *AuditCmd) Audit(opksshVersion string) (*TotalResults, error) {
validator := policy.NewPolicyValidator(providerPolicy)

// Audit policy file
systemResults, exists, err := a.auditPolicyFileWithStatus(policyPath, []fs.FileMode{files.ModeSystemPerms}, validator)
systemResults, exists, err := a.auditPolicyFileWithStatus(policyPath, files.RequiredPerms.SystemPolicy, validator)
if err != nil {
return nil, fmt.Errorf("failed to audit policy file: %v", err)
}
Expand All @@ -105,39 +99,30 @@ func (a *AuditCmd) Audit(opksshVersion string) (*TotalResults, error) {
}
}

// Audit user policy file if it exists and not skipping
// Audit user policy files if not skipping
if !a.SkipUserPolicy {
// We read /etc/passwd to enumerate all the home directories to find auth_id policy files.
var etcPasswdContent []byte
passwdPath := "/etc/passwd"
if exists, err := afero.Exists(a.Fs, passwdPath); !exists {
return nil, fmt.Errorf("failed to read /etc/passwd: /etc/passwd not found (needed to enumerate user home policies)")
} else if err != nil {
return nil, fmt.Errorf("failed to read /etc/passwd: %v", err)
homeDirs, err := a.enumerateUserHomeDirs()
if err != nil {
fmt.Fprintf(a.ErrOut, "warning: could not enumerate user home directories: %v\n", err)
} else {
etcPasswdContent, err = afero.ReadFile(a.Fs, passwdPath)
if err != nil {
return nil, fmt.Errorf("failed to read /etc/passwd: %v", err)
}
}
homeDirs := getHomeDirsFromEtcPasswd(string(etcPasswdContent))
for _, row := range homeDirs {
userPolicyPath := filepath.Join(row.HomeDir, ".opk", "auth_id")

userResults, userExists, err := a.auditPolicyFileWithStatus(userPolicyPath, []fs.FileMode{files.ModeHomePerms}, validator)
if err != nil {
fmt.Fprintf(a.ErrOut, "failed to audit user policy file at %s: %v\n", userPolicyPath, err)
totalResults.HomePolicyFiles = append(totalResults.HomePolicyFiles,
PolicyFileResult{FilePath: userPolicyPath, Error: err.Error()})
// Don't fail completely if user policy is unreadable
} else if userExists {
fmt.Fprintf(a.ErrOut, "\nvalidating %s...\n", userPolicyPath)
if !a.JsonOutput {
for _, result := range userResults.Rows {
a.printResult(result)
for _, row := range homeDirs {
userPolicyPath := filepath.Join(row.HomeDir, ".opk", "auth_id")

userResults, userExists, err := a.auditPolicyFileWithStatus(userPolicyPath, files.RequiredPerms.HomePolicy, validator)
if err != nil {
fmt.Fprintf(a.ErrOut, "failed to audit user policy file at %s: %v\n", userPolicyPath, err)
totalResults.HomePolicyFiles = append(totalResults.HomePolicyFiles,
PolicyFileResult{FilePath: userPolicyPath, Error: err.Error()})
// Don't fail completely if user policy is unreadable
} else if userExists {
fmt.Fprintf(a.ErrOut, "\nvalidating %s...\n", userPolicyPath)
if !a.JsonOutput {
for _, result := range userResults.Rows {
a.printResult(result)
}
}
totalResults.HomePolicyFiles = append(totalResults.HomePolicyFiles, *userResults)
}
totalResults.HomePolicyFiles = append(totalResults.HomePolicyFiles, *userResults)
}
}
}
Expand Down Expand Up @@ -189,29 +174,31 @@ func (a *AuditCmd) Run(opksshVersion string) error {
}

// auditPolicyFileWithStatus validates all entries in a policy file and returns results, whether file exists, and any errors
func (a *AuditCmd) auditPolicyFileWithStatus(policyPath string, requiredPerms []fs.FileMode, validator *policy.PolicyValidator) (*PolicyFileResult, bool, error) {
func (a *AuditCmd) auditPolicyFileWithStatus(policyPath string, permInfo files.PermInfo, validator *policy.PolicyValidator) (*PolicyFileResult, bool, error) {
results := &PolicyFileResult{
FilePath: policyPath,
Rows: []policy.ValidationRowResult{},
}

// Check if file exists
exists, err := afero.Exists(a.Fs, policyPath)
if err != nil {
return nil, false, fmt.Errorf("failed to check if policy file exists: %w", err)
// Use shared permission checking logic
permResult := CheckFilePermissions(a.Fs, policyPath, permInfo)
if !permResult.Exists {
return results, false, nil
}

if !exists {
// File doesn't exist, return empty results with exists=false
return results, false, nil
if permResult.PermsErr != "" {
results.PermsError = permResult.PermsErr
}

if permsErr := a.filePermsChecker.CheckPerm(policyPath, requiredPerms, "", ""); permsErr != nil {
results.PermsError = permsErr.Error()
// Report ACL problems to stderr
if permResult.ACLReport != nil && permResult.ACLErr == nil {
for _, problem := range permResult.ACLReport.Problems {
fmt.Fprintf(a.ErrOut, " ACL issue: %s\n", problem)
}
}

// Load policy file
content, err := afero.ReadFile(a.Fs, policyPath)
content, err := a.Fs.ReadFile(policyPath)
if err != nil {
return nil, true, fmt.Errorf("failed to read policy file: %w", err)
}
Expand Down Expand Up @@ -306,16 +293,16 @@ func getCurrentUsername() string {
return u.Username
}

type etcPasswdRow struct {
type userHomeEntry struct {
Username string
HomeDir string
}

// getHomeDirsFromEtcPasswd parses /etc/passwd and returns a list of usernames
// and their associated home directories. This is not sufficient for all home
// directories as it does not consider home directories specified by NSS.
func getHomeDirsFromEtcPasswd(etcPasswd string) []etcPasswdRow {
entries := []etcPasswdRow{}
func getHomeDirsFromEtcPasswd(etcPasswd string) []userHomeEntry {
entries := []userHomeEntry{}
for _, line := range strings.Split(etcPasswd, "\n") {
if line == "" || strings.HasPrefix(line, "#") {
continue
Expand All @@ -330,7 +317,7 @@ func getHomeDirsFromEtcPasswd(etcPasswd string) []etcPasswdRow {
continue
}

entry := etcPasswdRow{Username: parts[0], HomeDir: parts[5]}
entry := userHomeEntry{Username: parts[0], HomeDir: parts[5]}
entries = append(entries, entry)
}
return entries
Expand Down
44 changes: 44 additions & 0 deletions commands/audit_enum_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build !windows
// +build !windows

// Copyright 2026 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package commands

import (
"fmt"
)

// enumerateUserHomeDirs returns the list of user home directories by reading
// /etc/passwd. This is the Unix implementation.
func (a *AuditCmd) enumerateUserHomeDirs() ([]userHomeEntry, error) {
passwdPath := "/etc/passwd"
exists, err := a.Fs.Exists(passwdPath)
if err != nil {
return nil, fmt.Errorf("failed to check /etc/passwd: %w", err)
}
if !exists {
return nil, fmt.Errorf("/etc/passwd not found (needed to enumerate user home policies)")
}

etcPasswdContent, err := a.Fs.ReadFile(passwdPath)
if err != nil {
return nil, fmt.Errorf("failed to read /etc/passwd: %w", err)
}

return getHomeDirsFromEtcPasswd(string(etcPasswdContent)), nil
}
65 changes: 65 additions & 0 deletions commands/audit_enum_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//go:build !windows
// +build !windows

// Copyright 2026 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package commands

import (
"bytes"
"testing"

"github.com/openpubkey/opkssh/policy/files"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
)

func TestEnumerateUserHomeDirs_Unix(t *testing.T) {
t.Parallel()

vfs := afero.NewMemMapFs()
passwdContent := "root:x:0:0:root:/root:/bin/bash\nalice:x:1000:1000::/home/alice:/bin/sh\n"
require.NoError(t, afero.WriteFile(vfs, "/etc/passwd", []byte(passwdContent), 0o644))

cmd := &AuditCmd{
Fs: files.NewFileSystem(vfs, files.WithCmdRunner(func(string, ...string) ([]byte, error) { return nil, nil })),
Out: &bytes.Buffer{},
}

entries, err := cmd.enumerateUserHomeDirs()
require.NoError(t, err)
require.Len(t, entries, 2)
require.Equal(t, "root", entries[0].Username)
require.Equal(t, "/root", entries[0].HomeDir)
require.Equal(t, "alice", entries[1].Username)
require.Equal(t, "/home/alice", entries[1].HomeDir)
}

func TestEnumerateUserHomeDirs_Unix_MissingPasswd(t *testing.T) {
t.Parallel()

vfs := afero.NewMemMapFs()

cmd := &AuditCmd{
Fs: files.NewFileSystem(vfs, files.WithCmdRunner(func(string, ...string) ([]byte, error) { return nil, nil })),
Out: &bytes.Buffer{},
}

_, err := cmd.enumerateUserHomeDirs()
require.Error(t, err)
require.Contains(t, err.Error(), "/etc/passwd not found")
}
26 changes: 26 additions & 0 deletions commands/audit_enum_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//go:build windows
// +build windows

// Copyright 2026 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package commands

// enumerateUserHomeDirs returns the list of user home directories by reading
// the Windows registry ProfileList. This is the Windows implementation.
func (a *AuditCmd) enumerateUserHomeDirs() ([]userHomeEntry, error) {
return getHomeDirsFromProfileList()
}
49 changes: 49 additions & 0 deletions commands/audit_enum_windows_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//go:build windows
// +build windows

// Copyright 2026 OpenPubkey
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// SPDX-License-Identifier: Apache-2.0

package commands

import (
"bytes"
"testing"

"github.com/openpubkey/opkssh/policy/files"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
)

func TestEnumerateUserHomeDirs_Windows(t *testing.T) {
t.Parallel()

vfs := afero.NewMemMapFs()
cmd := &AuditCmd{
Fs: files.NewFileSystem(vfs, files.WithCmdRunner(func(string, ...string) ([]byte, error) { return nil, nil })),
Out: &bytes.Buffer{},
}

entries, err := cmd.enumerateUserHomeDirs()
require.NoError(t, err)
// On a running Windows machine there should be at least one user profile
require.NotEmpty(t, entries, "expected at least one user profile from Windows registry")

for _, e := range entries {
require.NotEmpty(t, e.Username, "username should not be empty")
require.NotEmpty(t, e.HomeDir, "home directory should not be empty")
}
}
Loading
Loading