diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..d811b65 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,26 @@ +version: 2 +updates: + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "monthly" + day: friday + time: "08:00" + labels: + - "dependencies" + commit-message: + prefix: "chore: " + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + day: friday + time: "08:00" + labels: + - "dependencies" + commit-message: + prefix: "chore: " + groups: + experimental-golang-deps: + patterns: + - "golang.org/x/*" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..5bef6ec --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,14 @@ +# Summary + +...enter summary here... + +## Notable Changes + +- ...enter notable changes here... +- ...enter notable changes here... + +## Change Type + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) diff --git a/.github/workflows/codeql.yaml b/.github/workflows/codeql.yaml new file mode 100644 index 0000000..0ad46e6 --- /dev/null +++ b/.github/workflows/codeql.yaml @@ -0,0 +1,38 @@ +name: CodeQL + +on: + push: + branches: + - main + pull_request: + + schedule: + - cron: "00 5 * * SAT" + +jobs: + codeql: + permissions: + actions: read + contents: read + security-events: write + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "^1.24.3" + + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: go + + - name: Autobuild + uses: github/codeql-action/autobuild@v3 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/validate.yaml b/.github/workflows/validate.yaml new file mode 100644 index 0000000..400d51b --- /dev/null +++ b/.github/workflows/validate.yaml @@ -0,0 +1,72 @@ +name: Validate + +on: + push: + branches: + - main + pull_request: + +jobs: + validate: + permissions: + contents: read # for actions/checkout to fetch code + security-events: write # for github/codeql-action/upload-sarif to upload SARIF results + runs-on: ubuntu-latest + steps: + - name: Checkout Source + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "^1.24.3" + + - name: Init project + run: | + go mod tidy + go generate ./... + + # ____ _ _ + # / ___| ___ ___ _ _ _ __(_) |_ _ _ + # \___ \ / _ \/ __| | | | '__| | __| | | | + # ___) | __/ (__| |_| | | | | |_| |_| | + # |____/ \___|\___|\__,_|_| |_|\__|\__, | + # |___/ + - name: Run Gosec Security Scanner + uses: securego/gosec@master + with: + args: "-no-fail -fmt sarif -out results.sarif ./..." + + - name: Upload SARIF file + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: results.sarif + + # _ _ _ + # | | (_)_ __ | |_ + # | | | | '_ \| __| + # | |___| | | | | |_ + # |_____|_|_| |_|\__| + # + - name: golangci-lint + uses: golangci/golangci-lint-action@v8 + with: + version: v2.1.6 + + # _____ _ + # |_ _|__ ___| |_ + # | |/ _ \/ __| __| + # | | __/\__ \ |_ + # |_|\___||___/\__| + # + - name: Run coverage + # TODO: Add -race flag when the container becomes thread safe + run: go test ./... -coverprofile=coverage.txt -covermode=atomic + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + files: ./coverage.txt + fail_ci_if_error: false diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..b190250 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,124 @@ +version: "2" +linters: + default: none + enable: + - asciicheck + - bidichk + - bodyclose + - cyclop + - decorder + - dupl + - durationcheck + - errcheck + - errname + - errorlint + - exhaustive + - funlen + - ginkgolinter + - gocognit + - goconst + - gocritic + - gocyclo + - gomoddirectives + - gomodguard + - gosec + - govet + - ineffassign + - lll + - loggercheck + - makezero + - nakedret + - nestif + - nilerr + - nilnil + - noctx + - nolintlint + - nosprintfhostport + - predeclared + - reassign + - staticcheck + - tagalign + - testableexamples + - testpackage + - tparallel + - unconvert + - unparam + - unused + - usestdlibvars + - wastedassign + - whitespace + - zerologlint + settings: + cyclop: + max-complexity: 30 + package-average: 10 + errcheck: + check-type-assertions: true + exhaustive: + check: + - switch + - map + funlen: + lines: 100 + statements: 50 + ignore-comments: true + gocognit: + min-complexity: 20 + govet: + disable: + - fieldalignment + enable-all: true + settings: + shadow: + strict: true + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + rules: + - linters: + - godot + source: (noinspection|TODO) + - linters: + - gocritic + source: //noinspection + - linters: + - lll + path: mocks\.go + - linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck + - exhaustive + - gocognit + - errcheck + path: _test\.go + - linters: + - staticcheck + text: SA5011 + - path: (.+)\.go$ + text: declaration of "(err|ctx)" shadows declaration at + - path: (.+)\.go$ + text: G115 + paths: + - third_party$ + - builtin$ + - examples$ +issues: + max-same-issues: 50 +formatters: + enable: + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f49a4e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e5dc424 --- /dev/null +++ b/README.md @@ -0,0 +1,75 @@ +# Vault + +A Go package for secure secret storage with multiple encryption backends. Made for [flow](https://github.com/jahvon/flow) but can be used independently. + +## Quick Start + +```go +package main + +import ( + "fmt" + "github.com/jahvon/vault" +) + +func main() { + // Create a new AES vault + v, err := vault.New("my-vault", + vault.WithProvider(vault.ProviderTypeAES256), + vault.WithLocalPath("/path/to/vault/storage"), + vault.WithAESKeyFromEnv("VAULT_KEY"), + ) + if err != nil { + panic(err) + } + defer v.Close() + + // Store a secret + secret := vault.NewSecretValue([]byte("my-secret-value")) + err = v.SetSecret("api-key", secret) + if err != nil { + panic(err) + } + + // Retrieve a secret + retrieved, err := v.GetSecret("api-key") + if err != nil { + panic(err) + } + fmt.Println("Secret:", retrieved.PlainTextString()) +} +``` + +## Providers + +### AES256 Provider + +Symmetric encryption using AES-256. Best for when you want a single encryption key shared across users / systems. + +**Key Generation:** +```go +key, err := vault.GenerateEncryptionKey() +if err != nil { + panic(err) +} +// Store this key securely and configure vault to use it +``` + +### Age Provider + +Asymmetric encryption using the [age encryption format](https://github.com/FiloSottile/age). Best for when you may have multiple users or need the ability to add/remove recipients. + +**Key Generation:** +```bash +# Generate age key pair - see https://github.com/FiloSottile/age for details +age-keygen -o key.txt +# Public key: age1ql3blv6a5y... +# Private key in key.txt +``` + +## Encrypted Files + +Both vault types create a single encrypted file at the specified path: + +- **AES256**: `vault-{id}.enc` +- **Age**: `vault-{id}.age` diff --git a/aes.go b/aes.go new file mode 100644 index 0000000..f2eccfd --- /dev/null +++ b/aes.go @@ -0,0 +1,274 @@ +package vault + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "gopkg.in/yaml.v3" + + "github.com/jahvon/vault/crypto" +) + +const ( + aesCurrentVaultVersion = 1 + aesVaultFileExt = "enc" +) + +// AESState represents the state of the local AES256 vault. +type AESState struct { + Metadata `yaml:"metadata"` + + Version int `json:"version"` + ID string `yaml:"id"` + Secrets map[string]string `yaml:"secrets"` +} + +// AES256Vault manages operations on an instance of a local vault backed by AES256 symmetric encryption. +type AES256Vault struct { + mu sync.RWMutex + id string + fullPath string + + state *AESState + resolver *KeyResolver + dek string +} + +// GenerateEncryptionKey generates a new AES encryption key +func GenerateEncryptionKey() (string, error) { + return crypto.GenerateKey() +} + +// DeriveEncryptionKey derives an AES encryption key from a passphrase +func DeriveEncryptionKey(passphrase, sal string) (string, string, error) { + key, salt, err := crypto.DeriveKey([]byte(passphrase), []byte(sal)) + if err != nil { + return "", "", fmt.Errorf("failed to derive encryption key: %w", err) + } + return key, salt, nil +} + +// ValidateEncryptionKey checks if a key is valid by attempting to encrypt/decrypt test data +func ValidateEncryptionKey(key string) error { + testData := "test-validation-data" + encrypted, err := crypto.EncryptValue(key, testData) + if err != nil { + return fmt.Errorf("key validation failed during encryption: %w", err) + } + + decrypted, err := crypto.DecryptValue(key, encrypted) + if err != nil { + return fmt.Errorf("key validation failed during decryption: %w", err) + } + + if decrypted != testData { + return fmt.Errorf("key validation failed: decrypted data does not match") + } + + return nil +} + +func NewAES256Vault(cfg *Config) (*AES256Vault, error) { + if cfg.Aes == nil { + return nil, fmt.Errorf("AES configuration is required") + } + + path := filepath.Join( + filepath.Clean(cfg.Aes.StoragePath), + filepath.Clean(fmt.Sprintf("%s-%s.%s", vaultFileBase, cfg.ID, aesVaultFileExt)), + ) + + vault := &AES256Vault{ + id: cfg.ID, + fullPath: path, + resolver: NewKeyResolver(cfg.Aes.KeySource), + } + + if err := vault.load(); err != nil { + return nil, fmt.Errorf("failed to load vault: %w", err) + } + + if vault.state == nil { + if err := vault.init(); err != nil { + return nil, fmt.Errorf("failed to initialize vault: %w", err) + } + } + + return vault, nil +} + +func (v *AES256Vault) init() error { + keys, err := v.resolver.ResolveKeys() + if err != nil { + return fmt.Errorf("no encryption key available for new vault: %w", err) + } + v.dek = keys[0] + + now := time.Now() + v.state = &AESState{ + Version: aesCurrentVaultVersion, + ID: v.id, + Metadata: Metadata{ + Created: now, + LastModified: now, + }, + Secrets: make(map[string]string), + } + + return v.save() +} + +// load retrieves the AESState from the vault file, decrypts it, and unmarshals it into an AESState struct. +func (v *AES256Vault) load() error { + data, err := os.ReadFile(filepath.Clean(v.fullPath)) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return fmt.Errorf("%w: failed to read vault file %s: %w", ErrVaultNotFound, v.fullPath, err) + } + + if len(data) == 0 { + return nil + } + + // try to decrypt the vault file using available keys + dataStr, key, err := v.resolver.TryDecrypt(string(data)) + if err != nil { + return err + } + v.dek = key + + var state AESState + if err := yaml.Unmarshal([]byte(dataStr), &state); err != nil { + return fmt.Errorf("failed to unmarshal vault state: %w", err) + } + v.state = &state + return nil +} + +// save encrypts and writes the vault contents to disk +func (v *AES256Vault) save() error { + if v.state == nil { + return nil + } + + if v.dek == "" { + return fmt.Errorf("no encryption key available for saving") + } + + v.state.LastModified = time.Now() + data, err := yaml.Marshal(v.state) + if err != nil { + return fmt.Errorf("failed to marshal vault state: %w", err) + } + encryptedDataStr, err := crypto.EncryptValue(v.dek, string(data)) + if err != nil { + return fmt.Errorf("failed to encrypt vault state: %w", err) + } + + // write to the file atomically + if err := os.MkdirAll(filepath.Dir(v.fullPath), 0750); err != nil { + return fmt.Errorf("failed to create vault directory: %w", err) + } + tempFile := v.fullPath + ".tmp" + if err := os.WriteFile(tempFile, []byte(encryptedDataStr), 0600); err != nil { + return fmt.Errorf("failed to write temp vault file: %w", err) + } + + if err := os.Rename(tempFile, v.fullPath); err != nil { + _ = os.Remove(tempFile) + return fmt.Errorf("failed to move vault file: %w", err) + } + + return nil +} + +func (v *AES256Vault) ID() string { + return v.id +} + +func (v *AES256Vault) Metadata() Metadata { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.state == nil { + return Metadata{} + } + return v.state.Metadata +} + +func (v *AES256Vault) GetSecret(key string) (Secret, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + value, exists := v.state.Secrets[key] + if !exists { + return nil, ErrSecretNotFound + } + + return NewSecretValue([]byte(value)), nil +} + +func (v *AES256Vault) SetSecret(key string, secret Secret) error { + v.mu.Lock() + defer v.mu.Unlock() + + if err := ValidateSecretKey(key); err != nil { + return err + } + + if v.state.Secrets == nil { + v.state.Secrets = make(map[string]string) + } + + v.state.Secrets[key] = secret.PlainTextString() + return v.save() +} + +func (v *AES256Vault) DeleteSecret(key string) error { + v.mu.Lock() + defer v.mu.Unlock() + + _, exists := v.state.Secrets[key] + if !exists { + return ErrSecretNotFound + } + + delete(v.state.Secrets, key) + return v.save() +} + +func (v *AES256Vault) ListSecrets() ([]string, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + keys := make([]string, 0, len(v.state.Secrets)) + for k := range v.state.Secrets { + keys = append(keys, k) + } + return keys, nil +} + +func (v *AES256Vault) HasSecret(key string) (bool, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + _, exists := v.state.Secrets[key] + return exists, nil +} + +func (v *AES256Vault) Close() error { + // clear the secret state from memory + v.mu.Lock() + defer v.mu.Unlock() + + v.dek = "" + v.state = nil + + return nil +} diff --git a/aes_key.go b/aes_key.go new file mode 100644 index 0000000..6586bcb --- /dev/null +++ b/aes_key.go @@ -0,0 +1,91 @@ +package vault + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/jahvon/vault/crypto" +) + +type KeyResolver struct { + sources []KeySource +} + +func NewKeyResolver(sources []KeySource) *KeyResolver { + if len(sources) == 0 { + sources = []KeySource{ + {Type: envSource, Name: DefaultVaultKeyEnv}, + } + } + return &KeyResolver{ + sources: sources, + } +} + +func (r *KeyResolver) ResolveKeys() ([]string, error) { + var keys []string + + for _, source := range r.sources { + switch source.Type { + case envSource: + if key := r.fromEnvironment(source.Name); key != "" { + keys = append(keys, key) + } + case fileSource: + if key, err := r.fromFile(source.Path); err == nil && key != "" { + keys = append(keys, key) + } + } + } + + if len(keys) == 0 { + return nil, fmt.Errorf("%w: no encryption keys found", ErrNoAccess) + } + + return keys, nil +} + +func (r *KeyResolver) TryDecrypt(encryptedData string) (string, string, error) { + keys, err := r.ResolveKeys() + if err != nil { + return "", "", err + } + + for _, key := range keys { + decryptedData, err := crypto.DecryptValue(key, encryptedData) + if err != nil { + continue // try the next key + } + return decryptedData, key, nil + } + + return "", "", fmt.Errorf("%w: failed to decrypt data with any available key", ErrDecryptionFailed) +} + +func (r *KeyResolver) fromEnvironment(envVar string) string { + if envVar == "" { + envVar = DefaultVaultKeyEnv + } + + return os.Getenv(envVar) +} + +func (r *KeyResolver) fromFile(path string) (string, error) { + if path == "" { + return "", fmt.Errorf("key file path cannot be empty") + } + + expandedPath, err := expandPath(path) + if err != nil { + return "", fmt.Errorf("failed to expand key file path %s: %w", path, err) + } + + keyBytes, err := os.ReadFile(filepath.Clean(expandedPath)) + if err != nil { + return "", fmt.Errorf("failed to read key file %s: %w", expandedPath, err) + } + + return strings.TrimSpace(string(keyBytes)), nil +} diff --git a/aes_test.go b/aes_test.go new file mode 100644 index 0000000..66c025c --- /dev/null +++ b/aes_test.go @@ -0,0 +1,354 @@ +package vault_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/jahvon/vault" + "github.com/jahvon/vault/crypto" +) + +func TestAESKeyGeneration(t *testing.T) { + key1, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + if key1 == "" { + t.Error("Generated key should not be empty") + } + + key2, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate second key: %v", err) + } + if key1 == key2 { + t.Error("Generated keys should be unique") + } + + err = vault.ValidateEncryptionKey(key1) + if err != nil { + t.Errorf("Valid key failed validation: %v", err) + } + + err = vault.ValidateEncryptionKey("invalid-key") + if err == nil { + t.Error("Invalid key should fail validation") + } +} + +func TestAESKeyResolver(t *testing.T) { + tempDir := t.TempDir() + + // Test key from environment + testKey, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + + t.Setenv("TEST_AES_KEY", testKey) + + resolver := vault.NewKeyResolver([]vault.KeySource{ + {Type: "env", Name: "TEST_AES_KEY"}, + }) + + keys, err := resolver.ResolveKeys() + if err != nil { + t.Fatalf("Failed to resolve keys: %v", err) + } + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + if keys[0] != testKey { + t.Errorf("Expected key %s, got %s", testKey, keys[0]) + } + + // Test key from file + keyFile := filepath.Join(tempDir, "test-key.txt") + err = os.WriteFile(keyFile, []byte(testKey), 0600) + if err != nil { + t.Fatalf("Failed to write key file: %v", err) + } + + resolver = vault.NewKeyResolver([]vault.KeySource{ + {Type: "file", Path: keyFile}, + }) + + keys, err = resolver.ResolveKeys() + if err != nil { + t.Fatalf("Failed to resolve keys from file: %v", err) + } + if len(keys) != 1 { + t.Errorf("Expected 1 key from file, got %d", len(keys)) + } + if keys[0] != testKey { + t.Errorf("Expected key from file %s, got %s", testKey, keys[0]) + } + + // Test multiple sources + resolver = vault.NewKeyResolver([]vault.KeySource{ + {Type: "env", Name: "NONEXISTENT_KEY"}, + {Type: "file", Path: keyFile}, + {Type: "env", Name: "TEST_AES_KEY"}, + }) + + keys, err = resolver.ResolveKeys() + if err != nil { + t.Fatalf("Failed to resolve keys from multiple sources: %v", err) + } + if len(keys) != 2 { // Should find both the file and env key + t.Errorf("Expected 2 keys from multiple sources, got %d", len(keys)) + } +} + +func TestAESKeyResolverDecryption(t *testing.T) { + // Generate test keys + workingKey, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate working key: %v", err) + } + + wrongKey, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate wrong key: %v", err) + } + + // Encrypt test data with working key + testData := "test secret data" + encryptedData, err := crypto.EncryptValue(workingKey, testData) + if err != nil { + t.Fatalf("Failed to encrypt test data: %v", err) + } + + // Set up resolver with wrong key first, then right key + t.Setenv("WRONG_KEY", wrongKey) + t.Setenv("WORKING_KEY", workingKey) + + resolver := vault.NewKeyResolver([]vault.KeySource{ + {Type: "env", Name: "WRONG_KEY"}, + {Type: "env", Name: "WORKING_KEY"}, + }) + + // Test TryDecrypt - should succeed with working key + decryptedData, usedKey, err := resolver.TryDecrypt(encryptedData) + if err != nil { + t.Fatalf("Failed to decrypt with resolver: %v", err) + } + if decryptedData != testData { + t.Errorf("Expected decrypted data %s, got %s", testData, decryptedData) + } + if usedKey != workingKey { + t.Errorf("Expected working key %s, got %s", workingKey, usedKey) + } + + // Test with no working keys + resolver = vault.NewKeyResolver([]vault.KeySource{ + {Type: "env", Name: "WRONG_KEY"}, + }) + + _, _, err = resolver.TryDecrypt(encryptedData) + if err == nil { + t.Error("Expected decryption to fail with wrong key only") + } +} + +func TestAESVaultCreation(t *testing.T) { + tempDir := t.TempDir() + + // Test creating vault without AES config + _, err := vault.NewAES256Vault(&vault.Config{ + ID: "test", + Type: vault.ProviderTypeAES256, + }) + if err == nil { + t.Error("Expected error when creating vault without AES config") + } + + // Test creating vault with valid config + testKey, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + + t.Setenv("TEST_VAULT_KEY", testKey) + + config := &vault.Config{ + ID: "test-vault", + Type: vault.ProviderTypeAES256, + Aes: &vault.AesConfig{ + StoragePath: tempDir, + KeySource: []vault.KeySource{ + {Type: "env", Name: "TEST_VAULT_KEY"}, + }, + }, + } + + v, err := vault.NewAES256Vault(config) + if err != nil { + t.Fatalf("Failed to create AES vault: %v", err) + } + defer v.Close() + + if v.ID() != "test-vault" { + t.Errorf("Expected vault ID 'test-vault', got %s", v.ID()) + } + + // Verify vault file was created + expectedFile := filepath.Join(tempDir, "vault-test-vault.enc") + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Errorf("Expected vault file %s was not created", expectedFile) + } +} + +func TestAESVaultKeyResolution(t *testing.T) { + tempDir := t.TempDir() + + // Create vault with first key + key1, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate key1: %v", err) + } + + t.Setenv("VAULT_KEY_1", key1) + + config := &vault.Config{ + ID: "test-multi-key", + Type: vault.ProviderTypeAES256, + Aes: &vault.AesConfig{ + StoragePath: tempDir, + KeySource: []vault.KeySource{ + {Type: "env", Name: "VAULT_KEY_1"}, + }, + }, + } + + vault1, err := vault.NewAES256Vault(config) + if err != nil { + t.Fatalf("Failed to create vault with key1: %v", err) + } + + // Add a secret + secret := vault.NewSecretValue([]byte("test-secret")) + err = vault1.SetSecret("test-key", secret) + if err != nil { + t.Fatalf("Failed to set secret: %v", err) + } + _ = vault1.Close() + + // Now create a new key and try to access vault with multiple key sources + key2, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate key2: %v", err) + } + + t.Setenv("VAULT_KEY_2", key2) + + // Configure vault with wrong key first, then right key + config.Aes.KeySource = []vault.KeySource{ + {Type: "env", Name: "VAULT_KEY_2"}, // Wrong key first + {Type: "env", Name: "VAULT_KEY_1"}, // Right key second + } + + vault2, err := vault.NewAES256Vault(config) + if err != nil { + t.Fatalf("Failed to create vault with multiple keys: %v", err) + } + defer vault2.Close() + + // Should be able to retrieve the secret using the second key + retrievedSecret, err := vault2.GetSecret("test-key") + if err != nil { + t.Fatalf("Failed to get secret with multiple key sources: %v", err) + } + if retrievedSecret.PlainTextString() != "test-secret" { + t.Errorf("Expected 'test-secret', got %s", retrievedSecret.PlainTextString()) + } +} + +func TestAESVaultFileFormat(t *testing.T) { + tempDir := t.TempDir() + + testKey, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + t.Setenv("FILE_FORMAT_KEY", testKey) + + config := &vault.Config{ + ID: "format-test", + Type: vault.ProviderTypeAES256, + Aes: &vault.AesConfig{ + StoragePath: tempDir, + KeySource: []vault.KeySource{ + {Type: "env", Name: "FILE_FORMAT_KEY"}, + }, + }, + } + + vault1, err := vault.NewAES256Vault(config) + if err != nil { + t.Fatalf("Failed to create vault: %v", err) + } + + _ = vault1.SetSecret("key1", vault.NewSecretValue([]byte("value1"))) + _ = vault1.SetSecret("key2", vault.NewSecretValue([]byte("value2"))) + _ = vault1.Close() + + vaultFile := filepath.Join(tempDir, "vault-format-test.enc") + data, err := os.ReadFile(vaultFile) + if err != nil { + t.Fatalf("Failed to read vault file: %v", err) + } + + if len(data) == 0 { + t.Error("Vault file should not be empty") + } + + // Verify the file is encrypted (should not contain plain text) + dataStr := string(data) + if strings.Contains(dataStr, "key1") || + strings.Contains(dataStr, "value1") || + strings.Contains(dataStr, "key2") || + strings.Contains(dataStr, "value2") { + t.Error("Vault file should not contain plain text secrets") + } + + // Verify file can be decrypted by creating new vault + vault2, err := vault.NewAES256Vault(config) + if err != nil { + t.Fatalf("Failed to recreate vault: %v", err) + } + defer vault2.Close() + + secret1, err := vault2.GetSecret("key1") + if err != nil { + t.Fatalf("Failed to decrypt secret: %v", err) + } + if secret1.PlainTextString() != "value1" { + t.Errorf("Expected 'value1', got %s", secret1.PlainTextString()) + } +} + +func TestAESDefaultKeySource(t *testing.T) { + // Test that KeyResolver works with default sources when nil is provided + // This is a behavioral test rather than testing internal implementation + testKey, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + t.Setenv(vault.DefaultVaultKeyEnv, testKey) + + resolver := vault.NewKeyResolver(nil) + keys, err := resolver.ResolveKeys() + if err != nil { + t.Fatalf("Failed to resolve keys with default source: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key from default source, got %d", len(keys)) + } + if keys[0] != testKey { + t.Errorf("Expected key %s, got %s", testKey, keys[0]) + } +} diff --git a/age.go b/age.go new file mode 100644 index 0000000..43b842b --- /dev/null +++ b/age.go @@ -0,0 +1,325 @@ +package vault + +import ( + "bytes" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "filippo.io/age" +) + +const ( + ageCurrentVaultVersion = 1 + ageVaultFileExt = "age" +) + +// AgeState represents the state of the local age vault +type AgeState struct { + Metadata `json:"metadata"` + + Version int `json:"version"` + ID string `json:"id"` + Recipients []string `json:"recipients"` + Secrets map[string]string `json:"secrets"` +} + +// AgeVault manages operations on an instance of a local vault backed by age encryption. +type AgeVault struct { + mu sync.RWMutex + id string + fullPath string + + cfg *AgeConfig + state *AgeState + resolver *IdentityResolver + + identities []age.Identity + recipients []age.Recipient +} + +func NewAgeVault(cfg *Config) (*AgeVault, error) { + if cfg.Age == nil { + return nil, fmt.Errorf("age configuration is required") + } + + path := filepath.Join( + filepath.Clean(cfg.Age.StoragePath), + filepath.Clean(fmt.Sprintf("%s-%s.%s", vaultFileBase, cfg.ID, ageVaultFileExt)), + ) + + vault := &AgeVault{ + mu: sync.RWMutex{}, + fullPath: path, + id: cfg.ID, + cfg: cfg.Age, + resolver: NewIdentityResolver(cfg.Age.IdentitySources), + } + + ids, err := vault.resolver.ResolveIdentities() + if err != nil { + return nil, fmt.Errorf("failed to resolve identities: %w", err) + } + vault.identities = ids + + if err := vault.load(); err != nil { + return nil, fmt.Errorf("failed to load vault: %w", err) + } + + if vault.state == nil { + if err := vault.init(); err != nil { + return nil, fmt.Errorf("failed to initialize vault: %w", err) + } + } + + return vault, nil +} + +func (v *AgeVault) init() error { + now := time.Now() + v.state = &AgeState{ + Version: ageCurrentVaultVersion, + ID: v.id, + Metadata: Metadata{ + Created: now, + LastModified: now, + }, + Recipients: v.cfg.Recipients, + Secrets: make(map[string]string), + } + + for _, recipientKey := range v.cfg.Recipients { + if err := v.addRecipientToState(recipientKey); err != nil { + return fmt.Errorf("failed to add initial recipient %s: %w", recipientKey, err) + } + } + + if len(v.state.Recipients) == 0 { + return fmt.Errorf("no recipients available for encryption, please add at least one recipient") + } + + if err := v.parseRecipients(); err != nil { + return fmt.Errorf("failed to parse recipients: %w", err) + } + + return v.save() +} + +// load reads the vault file and decrypts its contents +func (v *AgeVault) load() error { + data, err := os.ReadFile(v.fullPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read vault file: %w", err) + } + + if len(data) == 0 { + return nil + } + + r, err := age.Decrypt(bytes.NewReader(data), v.identities...) + if err != nil { + return fmt.Errorf("failed to decrypt vault file - do you have the right key?: %w", err) + } + + var state AgeState + if err := json.NewDecoder(r).Decode(&state); err != nil { + return fmt.Errorf("failed to unmarshal vault state: %w", err) + } + + v.state = &state + if err := v.parseRecipients(); err != nil { + return fmt.Errorf("failed to parse recipients: %w", err) + } + + return nil +} + +// save encrypts and writes the vault contents to disk +func (v *AgeVault) save() error { + if v.state == nil { + return nil + } + + if len(v.recipients) == 0 { + return fmt.Errorf("no recipients available for encryption") + } + + v.state.LastModified = time.Now() + data, err := json.Marshal(v.state) + if err != nil { + return fmt.Errorf("failed to marshal vault state: %w", err) + } + + var buf bytes.Buffer + // encrypt the entire file using age + w, err := age.Encrypt(&buf, v.recipients...) + if err != nil { + return fmt.Errorf("failed to create age encryptor: %w", err) + } + if _, err := w.Write(data); err != nil { + return fmt.Errorf("failed to encrypt AESState: %w", err) + } + if err := w.Close(); err != nil { + return fmt.Errorf("failed to finalize encryption: %w", err) + } + + // write to the file atomically + if err := os.MkdirAll(filepath.Dir(v.fullPath), 0750); err != nil { + return fmt.Errorf("failed to create vault directory: %w", err) + } + tempFile := v.fullPath + ".tmp" + if err := os.WriteFile(tempFile, buf.Bytes(), 0600); err != nil { + return fmt.Errorf("failed to write temp vault file: %w", err) + } + + if err := os.Rename(tempFile, v.fullPath); err != nil { + _ = os.Remove(tempFile) + return fmt.Errorf("failed to move vault file: %w", err) + } + + return nil +} + +func (v *AgeVault) ID() string { + return v.id +} + +func (v *AgeVault) Metadata() Metadata { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.state == nil { + return Metadata{} + } + return v.state.Metadata +} + +func (v *AgeVault) GetSecret(key string) (Secret, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + value, exists := v.state.Secrets[key] + if !exists { + return nil, ErrSecretNotFound + } + + return NewSecretValue([]byte(value)), nil +} + +func (v *AgeVault) SetSecret(key string, value Secret) error { + v.mu.Lock() + defer v.mu.Unlock() + + if err := ValidateSecretKey(key); err != nil { + return err + } + + if v.state.Secrets == nil { + v.state.Secrets = make(map[string]string) + } + + v.state.Secrets[key] = value.PlainTextString() + return v.save() +} + +func (v *AgeVault) DeleteSecret(key string) error { + v.mu.Lock() + defer v.mu.Unlock() + + _, exists := v.state.Secrets[key] + if !exists { + return ErrSecretNotFound + } + + delete(v.state.Secrets, key) + return v.save() +} + +func (v *AgeVault) ListSecrets() ([]string, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + keys := make([]string, 0, len(v.state.Secrets)) + for k := range v.state.Secrets { + keys = append(keys, k) + } + return keys, nil +} + +func (v *AgeVault) HasSecret(key string) (bool, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + _, exists := v.state.Secrets[key] + return exists, nil +} + +func (v *AgeVault) Close() error { + // clear the secret state from memory + v.mu.Lock() + defer v.mu.Unlock() + + v.state = nil + v.recipients = nil + v.identities = nil + + return nil +} + +func (v *AgeVault) AddRecipient(publicKey string) error { + v.mu.Lock() + defer v.mu.Unlock() + + if err := v.addRecipientToState(publicKey); err != nil { + return err + } + if err := v.parseRecipients(); err != nil { + return fmt.Errorf("failed to parse recipients: %w", err) + } + + return v.save() +} + +func (v *AgeVault) RemoveRecipient(publicKey string) error { + v.mu.Lock() + defer v.mu.Unlock() + + // Don't allow removing the last recipient + if len(v.state.Recipients) <= 1 { + return fmt.Errorf("cannot remove the last recipient - at least one recipient is required for encryption") + } + + found := false + for i, rec := range v.state.Recipients { + if rec == publicKey { + v.state.Recipients = append(v.state.Recipients[:i], v.state.Recipients[i+1:]...) + found = true + break + } + } + + if !found { + return fmt.Errorf("recipient %s not found", publicKey) + } + + if err := v.parseRecipients(); err != nil { + return fmt.Errorf("failed to parse recipients: %w", err) + } + + return v.save() +} + +func (v *AgeVault) ListRecipients() ([]string, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + recipients := make([]string, len(v.state.Recipients)) + copy(recipients, v.state.Recipients) // prevent modification of internal state + return recipients, nil +} diff --git a/local_identity.go b/age_identity.go similarity index 75% rename from local_identity.go rename to age_identity.go index 0308c68..4eb4029 100644 --- a/local_identity.go +++ b/age_identity.go @@ -8,10 +8,6 @@ import ( "filippo.io/age" ) -var ( - DefaultVaultKeyEnv = "AGE_VAULT_KEY" -) - type IdentityResolver struct { sources []IdentitySource } @@ -19,7 +15,7 @@ type IdentityResolver struct { func NewIdentityResolver(sources []IdentitySource) *IdentityResolver { if len(sources) == 0 { sources = []IdentitySource{ - {Type: "env", Name: DefaultVaultKeyEnv}, + {Type: envSource, Name: DefaultVaultKeyEnv}, } } return &IdentityResolver{sources: sources} @@ -30,11 +26,11 @@ func (r *IdentityResolver) ResolveIdentities() ([]age.Identity, error) { for _, source := range r.sources { switch source.Type { - case "env": + case envSource: if id := r.fromEnvironment(source.Name); id != nil { identities = append(identities, id) } - case "file": + case fileSource: if id, err := r.fromFile(source.Path); err != nil { return nil, fmt.Errorf("failed to read identity from file %s: %w", source.Path, err) } else if id != nil { @@ -44,7 +40,7 @@ func (r *IdentityResolver) ResolveIdentities() ([]age.Identity, error) { } if len(identities) == 0 { - return nil, fmt.Errorf("no valid identities found") + return nil, fmt.Errorf("%w: no valid identities found", ErrNoAccess) } return identities, nil @@ -73,7 +69,11 @@ func (r *IdentityResolver) fromFile(path string) (age.Identity, error) { return nil, fmt.Errorf("identity file path cannot be empty") } - expandedPath := expandPath(path) + expandedPath, err := expandPath(path) + if err != nil { + return nil, fmt.Errorf("failed to expand identity file path %s: %w", path, err) + } + keyBytes, err := os.ReadFile(expandedPath) if err != nil { return nil, fmt.Errorf("failed to read identity file %s: %w", expandedPath, err) @@ -87,29 +87,29 @@ func (r *IdentityResolver) fromFile(path string) (age.Identity, error) { return identity, nil } -func (v *LocalVault) addRecipientToState(publicKey string) error { +func (v *AgeVault) addRecipientToState(publicKey string) error { _, err := age.ParseX25519Recipient(publicKey) if err != nil { - return fmt.Errorf("invalid recipient key: %w", err) + return fmt.Errorf("%w: invalid recipient key: %w", ErrInvalidRecipient, err) } - // for _, existing := range v.state.Recipients { - // if existing == publicKey { - // return fmt.Errorf("recipient already exists") - // } - // } + for _, existing := range v.state.Recipients { + if existing == publicKey { + return nil + } + } v.state.Recipients = append(v.state.Recipients, publicKey) return nil } -func (v *LocalVault) parseRecipients() error { +func (v *AgeVault) parseRecipients() error { v.recipients = make([]age.Recipient, 0, len(v.state.Recipients)) for _, recipientStr := range v.state.Recipients { recipient, err := age.ParseX25519Recipient(recipientStr) if err != nil { - return fmt.Errorf("invalid recipient %s: %w", recipientStr, err) + return fmt.Errorf("%w: invalid recipient %s: %w", ErrInvalidRecipient, recipientStr, err) } v.recipients = append(v.recipients, recipient) } diff --git a/age_test.go b/age_test.go new file mode 100644 index 0000000..2595c91 --- /dev/null +++ b/age_test.go @@ -0,0 +1,435 @@ +package vault_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/jahvon/vault" +) + +func TestAgeIdentityResolver(t *testing.T) { + tempDir := t.TempDir() + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + t.Setenv("TEST_AGE_IDENTITY", testIdentity) + + resolver := vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "env", Name: "TEST_AGE_IDENTITY"}, + }) + + identities, err := resolver.ResolveIdentities() + if err != nil { + t.Fatalf("Failed to resolve identities: %v", err) + } + if len(identities) != 1 { + t.Errorf("Expected 1 identity, got %d", len(identities)) + } + + // Test identity from file + keyFile := filepath.Join(tempDir, "test-identity.txt") + err = os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + resolver = vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "file", Path: keyFile}, + }) + + identities, err = resolver.ResolveIdentities() + if err != nil { + t.Fatalf("Failed to resolve identities from file: %v", err) + } + if len(identities) != 1 { + t.Errorf("Expected 1 identity from file, got %d", len(identities)) + } + + // Test multiple sources + resolver = vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "env", Name: "NONEXISTENT_IDENTITY"}, + {Type: "file", Path: keyFile}, + {Type: "env", Name: "TEST_AGE_IDENTITY"}, + }) + + identities, err = resolver.ResolveIdentities() + if err != nil { + t.Fatalf("Failed to resolve identities from multiple sources: %v", err) + } + if len(identities) != 2 { // Should find both the file and env identity + t.Errorf("Expected 2 identities from multiple sources, got %d", len(identities)) + } +} + +func TestAgeIdentityResolverErrors(t *testing.T) { + // Test invalid identity + t.Setenv("INVALID_AGE_IDENTITY", "not-a-valid-age-key") + + resolver := vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "env", Name: "INVALID_AGE_IDENTITY"}, + }) + + identities, err := resolver.ResolveIdentities() + if err == nil { + t.Error("Expected error when no valid identities found") + } + if len(identities) != 0 { + t.Errorf("Expected 0 identities for invalid key, got %d", len(identities)) + } + + // Test nonexistent file + resolver = vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "file", Path: "/nonexistent/path/key.txt"}, + }) + + _, err = resolver.ResolveIdentities() + if err == nil { + t.Error("Expected error for nonexistent file") + } + + // Test empty file path + resolver = vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "file", Path: ""}, + }) + + _, err = resolver.ResolveIdentities() + if err == nil { + t.Error("Expected error for empty file path") + } + + // Test no valid identities + resolver = vault.NewIdentityResolver([]vault.IdentitySource{ + {Type: "env", Name: "NONEXISTENT_KEY"}, + }) + + _, err = resolver.ResolveIdentities() + if err == nil { + t.Error("Expected error when no valid identities found") + } +} + +func TestAgeVaultCreation(t *testing.T) { + tempDir := t.TempDir() + + // Test creating vault without Age config + _, err := vault.NewAgeVault(&vault.Config{ + ID: "test", + Type: vault.ProviderTypeAge, + }) + if err == nil { + t.Error("Expected error when creating vault without Age config") + } + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + testRecipient := "age1wnhg53pg2qfsfxwvxvlg6pygw5uzwcyhj2dqhg0k83fvjexf9pzsxqdvs0" + keyFile := filepath.Join(tempDir, "test-key.txt") + err = os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + config := &vault.Config{ + ID: "test-age-vault", + Type: vault.ProviderTypeAge, + Age: &vault.AgeConfig{ + StoragePath: tempDir, + IdentitySources: []vault.IdentitySource{ + {Type: "file", Path: keyFile}, + }, + Recipients: []string{testRecipient}, + }, + } + + v, err := vault.NewAgeVault(config) + if err != nil { + t.Fatalf("Failed to create Age vault: %v", err) + } + defer v.Close() + + if v.ID() != "test-age-vault" { + t.Errorf("Expected vault ID 'test-age-vault', got %s", v.ID()) + } + + // Verify vault file was created + expectedFile := filepath.Join(tempDir, "vault-test-age-vault.age") + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Errorf("Expected vault file %s was not created", expectedFile) + } +} + +func TestAgeVaultRecipientManagement(t *testing.T) { + tempDir := t.TempDir() + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + testRecipient1 := "age1wnhg53pg2qfsfxwvxvlg6pygw5uzwcyhj2dqhg0k83fvjexf9pzsxqdvs0" + testRecipient2 := "age1u7rkgxlu26y68m3ky0aesxtls9g33zy5zcy0wuehtwua6lssmpus4xszw6" + keyFile := filepath.Join(tempDir, "test-key.txt") + err := os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + config := &vault.Config{ + ID: "recipient-test", + Type: vault.ProviderTypeAge, + Age: &vault.AgeConfig{ + StoragePath: tempDir, + IdentitySources: []vault.IdentitySource{ + {Type: "file", Path: keyFile}, + }, + Recipients: []string{testRecipient1}, + }, + } + + v, err := vault.NewAgeVault(config) + if err != nil { + t.Fatalf("Failed to create Age vault: %v", err) + } + defer v.Close() + + // Test initial recipients + recipients, err := v.ListRecipients() + if err != nil { + t.Fatalf("Failed to list recipients: %v", err) + } + if len(recipients) != 1 { + t.Errorf("Expected 1 initial recipient, got %d", len(recipients)) + } + if recipients[0] != testRecipient1 { + t.Errorf("Expected recipient %s, got %s", testRecipient1, recipients[0]) + } + + // Test adding recipient + err = v.AddRecipient(testRecipient2) + if err != nil { + t.Fatalf("Failed to add recipient: %v", err) + } + + recipients, err = v.ListRecipients() + if err != nil { + t.Fatalf("Failed to list recipients after add: %v", err) + } + if len(recipients) != 2 { + t.Errorf("Expected 2 recipients after add, got %d", len(recipients)) + } + + // Test adding duplicate recipient (should not fail) + err = v.AddRecipient(testRecipient1) + if err != nil { + t.Fatalf("Failed to add duplicate recipient: %v", err) + } + + recipients, err = v.ListRecipients() + if err != nil { + t.Fatalf("Failed to list recipients after duplicate add: %v", err) + } + if len(recipients) != 2 { // Should still be 2, duplicate doesn't increase count + t.Errorf("Expected 2 recipients after duplicate add, got %d", len(recipients)) + } + + // Test removing recipient (should succeed now that we have 2) + err = v.RemoveRecipient(testRecipient2) + if err != nil { + t.Fatalf("Failed to remove recipient: %v", err) + } + + recipients, err = v.ListRecipients() + if err != nil { + t.Fatalf("Failed to list recipients after remove: %v", err) + } + if len(recipients) != 1 { + t.Errorf("Expected 1 recipient after remove, got %d", len(recipients)) + } + + // Test removing the last recipient (should fail) + err = v.RemoveRecipient(testRecipient1) + if err == nil { + t.Error("Expected error when removing the last recipient") + } + + // Test removing nonexistent recipient + err = v.RemoveRecipient("age1nonexistent123456789") + if err == nil { + t.Error("Expected error when removing nonexistent recipient") + } +} + +func TestAgeVaultInvalidRecipient(t *testing.T) { + tempDir := t.TempDir() + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + keyFile := filepath.Join(tempDir, "test-key.txt") + err := os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + config := &vault.Config{ + ID: "invalid-recipient-test", + Type: vault.ProviderTypeAge, + Age: &vault.AgeConfig{ + StoragePath: tempDir, + IdentitySources: []vault.IdentitySource{ + {Type: "file", Path: keyFile}, + }, + Recipients: []string{"invalid-recipient-key"}, + }, + } + + // Should fail during vault creation due to invalid recipient + _, err = vault.NewAgeVault(config) + if err == nil { + t.Error("Expected error when creating vault with invalid recipient") + } +} + +func TestAgeVaultFileFormat(t *testing.T) { + tempDir := t.TempDir() + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + testRecipient := "age1wnhg53pg2qfsfxwvxvlg6pygw5uzwcyhj2dqhg0k83fvjexf9pzsxqdvs0" + + keyFile := filepath.Join(tempDir, "test-key.txt") + err := os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + config := &vault.Config{ + ID: "format-test", + Type: vault.ProviderTypeAge, + Age: &vault.AgeConfig{ + StoragePath: tempDir, + IdentitySources: []vault.IdentitySource{ + {Type: "file", Path: keyFile}, + }, + Recipients: []string{testRecipient}, + }, + } + + vault1, err := vault.NewAgeVault(config) + if err != nil { + t.Fatalf("Failed to create vault: %v", err) + } + + _ = vault1.SetSecret("key1", vault.NewSecretValue([]byte("value1"))) + _ = vault1.SetSecret("key2", vault.NewSecretValue([]byte("value2"))) + _ = vault1.Close() + + // Verify the encrypted file exists and has content + vaultFile := filepath.Join(tempDir, "vault-format-test.age") + data, err := os.ReadFile(vaultFile) + if err != nil { + t.Fatalf("Failed to read vault file: %v", err) + } + + if len(data) == 0 { + t.Error("Vault file should not be empty") + } + + // Verify the file starts with age format header + dataStr := string(data) + if !strings.HasPrefix(dataStr, "age-encryption.org/v1") { + t.Error("Age vault file should start with age format header") + } + + // Verify the file is encrypted (should not contain plain text) + if strings.Contains(dataStr, "key1") || + strings.Contains(dataStr, "value1") || + strings.Contains(dataStr, "key2") || + strings.Contains(dataStr, "value2") { + t.Error("Age vault file should not contain plain text secrets") + } + + // Verify file can be decrypted by creating new vault + vault2, err := vault.NewAgeVault(config) + if err != nil { + t.Fatalf("Failed to recreate vault: %v", err) + } + defer vault2.Close() + + secret1, err := vault2.GetSecret("key1") + if err != nil { + t.Fatalf("Failed to decrypt secret: %v", err) + } + if secret1.PlainTextString() != "value1" { + t.Errorf("Expected 'value1', got %s", secret1.PlainTextString()) + } +} + +func TestAgeVaultNoRecipients(t *testing.T) { + tempDir := t.TempDir() + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + keyFile := filepath.Join(tempDir, "test-key.txt") + err := os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + config := &vault.Config{ + ID: "no-recipients-test", + Type: vault.ProviderTypeAge, + Age: &vault.AgeConfig{ + StoragePath: tempDir, + IdentitySources: []vault.IdentitySource{ + {Type: "file", Path: keyFile}, + }, + Recipients: []string{}, + }, + } + + _, err = vault.NewAgeVault(config) + if err == nil { + t.Error("Expected error when creating vault with no recipients") + } +} + +func TestAgeVaultPathExpansion(t *testing.T) { + tempDir := t.TempDir() + + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + relativeKeyFile := "./test-key.txt" + + // Change to temp dir so relative path works + oldDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get working directory: %v", err) + } + defer os.Chdir(oldDir) + + err = os.Chdir(tempDir) + if err != nil { + t.Fatalf("Failed to change to temp directory: %v", err) + } + + err = os.WriteFile(relativeKeyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write identity file: %v", err) + } + + testRecipient := "age1wnhg53pg2qfsfxwvxvlg6pygw5uzwcyhj2dqhg0k83fvjexf9pzsxqdvs0" + config := &vault.Config{ + ID: "path-expansion-test", + Type: vault.ProviderTypeAge, + Age: &vault.AgeConfig{ + StoragePath: tempDir, + IdentitySources: []vault.IdentitySource{ + {Type: "file", Path: relativeKeyFile}, + }, + Recipients: []string{testRecipient}, + }, + } + + v, err := vault.NewAgeVault(config) + if err != nil { + t.Fatalf("Failed to create vault with relative path: %v", err) + } + defer v.Close() + + err = v.SetSecret("test", vault.NewSecretValue([]byte("value"))) + if err != nil { + t.Fatalf("Failed to set secret with relative path identity: %v", err) + } +} diff --git a/cmd/main.go b/cmd/main.go deleted file mode 100644 index 0987d19..0000000 --- a/cmd/main.go +++ /dev/null @@ -1,72 +0,0 @@ -package main - -import ( - "fmt" - "log" - "path/filepath" - - "github.com/jahvon/vault" -) - -func main() { - dir := "/Users/jahvon/workspaces/github.com/jahvon/vault/playground" - fmt.Printf("Testing vault in: %s\n", dir) - - fmt.Println("\n=== Test 1: Create New Vault ===") - vault1, err := vault.New( - "test", - vault.WithProvider(vault.ProviderTypeLocal), - vault.WithRecipients("age1nmkk0tv7ntg5yld0uhxc9f05p0d6zwxcaftxcjvwy82djuuzg96skmuzlk"), - vault.WithLocalPath(dir), - vault.WithLocalIdentityFromFile("/Users/jahvon/workspaces/github.com/jahvon/vault/playground/key.txt"), - ) - if err != nil { - log.Fatal("Failed to create vault:", err) - } - defer vault1.Close() - - fmt.Printf("Created vault with ID: %s\n", vault1.ID()) - - fmt.Println("\n=== Test 2: Set and Get Secrets ===") - - err = vault1.SetSecret("api-key", vault.NewSecretValue([]byte("my-secret-api-key"))) - if err != nil { - log.Fatal("Failed to set secret:", err) - } - fmt.Println("✓ Set api-key") - - err = vault1.SetSecret("db-password", vault.NewSecretValue([]byte("super-secret-password"))) - if err != nil { - log.Fatal("Failed to set db-password:", err) - } - fmt.Println("✓ Set db-password") - - secret, err := vault1.GetSecret("api-key") - if err != nil { - log.Fatal("Failed to get secret:", err) - } - fmt.Printf("✓ Retrieved api-key: %s (masked: %s)\n", secret.PlainTextString(), secret.String()) - - fmt.Println("\n=== Test 3: List Secrets ===") - secrets, err := vault1.ListSecrets() - if err != nil { - log.Fatal("Failed to list secrets:", err) - } - - fmt.Printf("Found %d secrets:\n", len(secrets)) - for _, key := range secrets { - fmt.Printf(" - %s\n", key) - } - - fmt.Println("\n=== Test 4: Verify Encrypted File ===") - vaultFiles, err := filepath.Glob(filepath.Join(dir, "*.age")) - if err != nil { - log.Fatal("Failed to find vault files:", err) - } - - if len(vaultFiles) == 0 { - log.Fatal("No .age vault files found!") - } - - fmt.Printf("✓ Found encrypted vault file: %s\n", vaultFiles[0]) -} diff --git a/config.go b/config.go index 136e1a8..dc1214b 100644 --- a/config.go +++ b/config.go @@ -11,35 +11,42 @@ import ( type ProviderType string const ( - ProviderTypeLocal ProviderType = "local" + ProviderTypeAES256 ProviderType = "aes256" + ProviderTypeAge ProviderType = "age" ProviderTypeExternal ProviderType = "external" ) type Config struct { ID string `json:"id"` Type ProviderType `json:"type"` - Local *LocalConfig `json:"local,omitempty"` + Age *AgeConfig `json:"age,omitempty"` + Aes *AesConfig `json:"aes,omitempty"` External *ExternalConfig `json:"external,omitempty"` } func (c *Config) Validate() error { if c.ID == "" { - return fmt.Errorf("vault ID is required") + return fmt.Errorf("%w: vault ID is required", ErrInvalidConfig) } switch c.Type { - case ProviderTypeLocal: - if c.Local == nil { - return fmt.Errorf("local configuration required for local vault") + case ProviderTypeAge: + if c.Age == nil { + return fmt.Errorf("%w: age configuration required for the age vault provider", ErrInvalidConfig) } - return c.Local.Validate() + return c.Age.Validate() + case ProviderTypeAES256: + if c.Aes == nil { + return fmt.Errorf("%w: aes configuration required for the aes256 vault provider", ErrInvalidConfig) + } + return c.Aes.Validate() case ProviderTypeExternal: if c.External == nil { - return fmt.Errorf("external configuration required for external vault") + return fmt.Errorf("%w: external configuration required for external vault", ErrInvalidConfig) } return c.External.Validate() default: - return fmt.Errorf("unsupported vault type: %s", c.Type) + return fmt.Errorf("%w: unsupported vault type: %s", ErrInvalidConfig, c.Type) } } @@ -50,11 +57,11 @@ func SaveConfigJSON(config Config, path string) error { return fmt.Errorf("failed to marshal config: %w", err) } - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(path), 0750); err != nil { return fmt.Errorf("failed to create config directory: %w", err) } - if err := os.WriteFile(path, data, 0600); err != nil { + if err := os.WriteFile(filepath.Clean(path), data, 0600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } @@ -63,7 +70,7 @@ func SaveConfigJSON(config Config, path string) error { // LoadConfigJSON loads the vault configuration from a file in JSON format func LoadConfigJSON(path string) (Config, error) { - data, err := os.ReadFile(path) + data, err := os.ReadFile(filepath.Clean(path)) if err != nil { return Config{}, fmt.Errorf("failed to read config file: %w", err) } @@ -79,7 +86,7 @@ func LoadConfigJSON(path string) (Config, error) { // IdentitySource represents a source for the local vault identity keys type IdentitySource struct { // Type of identity source - // Must be one of: "env", "file", "ssh-agent" + // Must be one of: "env", "file" Type string `json:"type"` // Path to the identity file (for "file" type) Path string `json:"fullPath,omitempty"` @@ -87,8 +94,8 @@ type IdentitySource struct { Name string `json:"name,omitempty"` } -// LocalConfig contains local (age-based) vault configuration -type LocalConfig struct { +// AgeConfig contains local (age-based) vault configuration +type AgeConfig struct { // Storage location for the vault file StoragePath string `json:"storage_path"` @@ -99,9 +106,63 @@ type LocalConfig struct { Recipients []string `json:"recipients,omitempty"` } -func (c *LocalConfig) Validate() error { +func (c *AgeConfig) Validate() error { if c.StoragePath == "" { - return fmt.Errorf("storage fullPath is required for local vault") + return fmt.Errorf("%w: storage path is required for age vault", ErrInvalidConfig) + } + if len(c.IdentitySources) == 0 { + return fmt.Errorf("%w: at least one identity source is required for age vault", ErrInvalidConfig) + } + for _, source := range c.IdentitySources { + if source.Type != envSource && source.Type != fileSource { + return fmt.Errorf("%w: invalid identity source type: %s", ErrInvalidConfig, source.Type) + } + if source.Type == fileSource && source.Path == "" { + return fmt.Errorf("%w: path is required for file identity source", ErrInvalidConfig) + } + if source.Type == envSource && source.Name == "" { + return fmt.Errorf("%w: name is required for env identity source", ErrInvalidConfig) + } + } + return nil +} + +// KeySource represents a source for the local vault encryption keys +type KeySource struct { + // Type of data encryption key source + // Must be one of: "env", "file" + Type string `json:"type"` + // Path to the identity file (for "file" type) + Path string `json:"fullPath,omitempty"` + // Environment variable name (for "env" type) + Name string `json:"name,omitempty"` +} + +// AesConfig contains local (AES256-based) vault configuration +type AesConfig struct { + // Storage location for the vault file + StoragePath string `json:"storage_path"` + // DEK sources for decryption (in order of preference) + KeySource []KeySource `json:"key_sources,omitempty"` +} + +func (c *AesConfig) Validate() error { + if c.StoragePath == "" { + return fmt.Errorf("%w: storage path is required for AES vault", ErrInvalidConfig) + } + if len(c.KeySource) == 0 { + return fmt.Errorf("%w: at least one key source is required for AES vault", ErrInvalidConfig) + } + for _, source := range c.KeySource { + if source.Type != envSource && source.Type != fileSource { + return fmt.Errorf("%w: invalid key source type: %s", ErrInvalidConfig, source.Type) + } + if source.Type == fileSource && source.Path == "" { + return fmt.Errorf("%w: path is required for file key source", ErrInvalidConfig) + } + if source.Type == envSource && source.Name == "" { + return fmt.Errorf("%w: name is required for env key source", ErrInvalidConfig) + } } return nil } @@ -132,7 +193,7 @@ type ExternalConfig struct { func (c *ExternalConfig) Validate() error { if c.Commands.Get == "" || c.Commands.Set == "" { - return fmt.Errorf("get and set commands are required for external vault") + return fmt.Errorf("%w: get and set commands are required for external vault", ErrInvalidConfig) } return nil } diff --git a/crypto/crypto.go b/crypto/crypto.go new file mode 100644 index 0000000..e30948d --- /dev/null +++ b/crypto/crypto.go @@ -0,0 +1,122 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + + "golang.org/x/crypto/scrypt" +) + +// GenerateKey generates a random 32 byte key and returns it as a base64 encoded string. +func GenerateKey() (string, error) { + key := make([]byte, 32) + _, err := rand.Read(key) + if err != nil { + return "", fmt.Errorf("error reading random bytes: %w", err) + } + return EncodeValue(key), nil +} + +// DeriveKey derives a 32 byte key from the provided password and salt and returns +// the key and salt as base64 encoded strings. +// If salt is nil, a random salt will be generated. +func DeriveKey(password, salt []byte) (string, string, error) { + if salt == nil { + salt = make([]byte, 32) + if _, err := rand.Read(salt); err != nil { + return "", "", err + } + } + + key, err := scrypt.Key(password, salt, 1048576, 8, 1, 32) + if err != nil { + return "", "", err + } + + return EncodeValue(key), EncodeValue(salt), nil +} + +// EncodeValue encodes a byte slice as a base64 encoded string. +func EncodeValue(b []byte) string { + return base64.StdEncoding.EncodeToString(b) +} + +// DecodeValue decodes a base64 encoded string into a byte slice. +func DecodeValue(s string) ([]byte, error) { + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return nil, err + } + return data, nil +} + +// EncryptValue encrypts a string using AES-256-GCM and returns the encrypted value as a base64 encoded string. +// The encryption key used for encryption must be a base64 encoded string. +func EncryptValue(encryptionKey string, text string) (string, error) { + decodedMasterKey, err := DecodeValue(encryptionKey) + if err != nil { + return "", fmt.Errorf("error decoding master key: %w", err) + } + block, err := aes.NewCipher(decodedMasterKey) + if err != nil { + return "", fmt.Errorf("error creating new cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("error creating GCM: %w", err) + } + + plaintext := []byte(text) + // verify that the plaintext is not too long to fit in an int + if len(plaintext) > 64*1024*1024 { + return "", fmt.Errorf("plaintext too long to encrypt") + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("error reading random bytes: %w", err) + } + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + return EncodeValue(ciphertext), nil +} + +// DecryptValue decrypts a string using AES-256-GCM and returns the decrypted value as a string. +// The master key used for decryption must be a base64 encoded string. +func DecryptValue(encryptionKey string, text string) (string, error) { + decodedMasterKey, err := DecodeValue(encryptionKey) + if err != nil { + return "", fmt.Errorf("error decoding master key: %w", err) + } + block, err := aes.NewCipher(decodedMasterKey) + if err != nil { + return "", fmt.Errorf("error creating new cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", fmt.Errorf("error creating GCM: %w", err) + } + + ciphertext, err := DecodeValue(text) + if err != nil { + return "", fmt.Errorf("error decoding ciphertext: %w", err) + } + + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return "", fmt.Errorf("ciphertext too short") + } + + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", fmt.Errorf("decryption failed: %w", err) + } + + return string(plaintext), nil +} diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go new file mode 100644 index 0000000..7bd48d9 --- /dev/null +++ b/crypto/crypto_test.go @@ -0,0 +1,304 @@ +package crypto_test + +import ( + "strings" + "testing" + + "github.com/jahvon/vault/crypto" +) + +func TestGenerateKey(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + if key == "" { + t.Error("Generated key should not be empty") + } + + decodedKey, err := crypto.DecodeValue(key) + if err != nil { + t.Fatalf("Failed to decode generated key: %v", err) + } + if len(decodedKey) == 0 { + t.Error("Decoded key should not be empty") + } + + // Test uniqueness + key2, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate second key: %v", err) + } + if key == key2 { + t.Error("Generated keys should be unique") + } +} + +func TestDeriveKeyWithProvidedSalt(t *testing.T) { + salt, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate salt: %v", err) + } + decodedSalt, err := crypto.DecodeValue(salt) + if err != nil { + t.Fatalf("Failed to decode salt: %v", err) + } + if len(decodedSalt) == 0 { + t.Error("Decoded salt should not be empty") + } + + inputPassword := []byte("password") + derivedKey, outSalt, err := crypto.DeriveKey(inputPassword, decodedSalt) + if err != nil { + t.Fatalf("Failed to derive key: %v", err) + } + if derivedKey == "" { + t.Error("Derived key should not be empty") + } + if outSalt != salt { + t.Errorf("Output salt should equal input salt, got %s, expected %s", outSalt, salt) + } + + decodedDerivedKey, err := crypto.DecodeValue(derivedKey) + if err != nil { + t.Fatalf("Failed to decode derived key: %v", err) + } + if len(decodedDerivedKey) == 0 { + t.Error("Decoded derived key should not be empty") + } +} + +func TestDeriveKeyWithoutSalt(t *testing.T) { + inputPassword := []byte("password") + derivedKey, outSalt, err := crypto.DeriveKey(inputPassword, nil) + if err != nil { + t.Fatalf("Failed to derive key without salt: %v", err) + } + if derivedKey == "" { + t.Error("Derived key should not be empty") + } + if outSalt == "" { + t.Error("Generated salt should not be empty") + } + + decodedDerivedKey, err := crypto.DecodeValue(derivedKey) + if err != nil { + t.Fatalf("Failed to decode derived key: %v", err) + } + if len(decodedDerivedKey) == 0 { + t.Error("Decoded derived key should not be empty") + } + + // Test reproducibility with same salt + decodedSalt, err := crypto.DecodeValue(outSalt) + if err != nil { + t.Fatalf("Failed to decode output salt: %v", err) + } + + derivedKey2, outSalt2, err := crypto.DeriveKey(inputPassword, decodedSalt) + if err != nil { + t.Fatalf("Failed to derive key with same salt: %v", err) + } + if derivedKey != derivedKey2 { + t.Error("Keys derived with same password and salt should be identical") + } + if outSalt != outSalt2 { + t.Error("Output salt should be same when input salt is provided") + } +} + +func TestEncryptDecryptValue(t *testing.T) { + masterKey, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate master key: %v", err) + } + + testCases := []string{ + "test value", + "special chars: !@#$%^&*()", + "unicode text: 🔐 secret 🚀", + "", + "very long text " + strings.Repeat("a", 1000), + "multiline\ntext\nwith\nnewlines", + "text\twith\ttabs", + } + + for _, plaintext := range testCases { + t.Run("encrypt_decrypt_"+plaintext[:minInt(10, len(plaintext))], func(t *testing.T) { + encryptedValue, err := crypto.EncryptValue(masterKey, plaintext) + if err != nil { + t.Fatalf("Failed to encrypt: %v", err) + } + if encryptedValue == "" { + t.Error("Encrypted value should not be empty") + } + if encryptedValue == plaintext && plaintext != "" { + t.Error("Encrypted value should not equal plaintext") + } + + decryptedValue, err := crypto.DecryptValue(masterKey, encryptedValue) + if err != nil { + t.Fatalf("Failed to decrypt: %v", err) + } + if decryptedValue != plaintext { + t.Errorf("Decrypted value doesn't match. Expected %q, got %q", plaintext, decryptedValue) + } + }) + } +} + +func TestEncryptionUniqueness(t *testing.T) { + masterKey, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate master key: %v", err) + } + + plaintext := "same data" + + encrypted1, err := crypto.EncryptValue(masterKey, plaintext) + if err != nil { + t.Fatalf("Failed to encrypt first time: %v", err) + } + + encrypted2, err := crypto.EncryptValue(masterKey, plaintext) + if err != nil { + t.Fatalf("Failed to encrypt second time: %v", err) + } + + if encrypted1 == encrypted2 { + t.Error("Encrypting same data twice should produce different ciphertext") + } + + // Both should decrypt to same value + decrypted1, err := crypto.DecryptValue(masterKey, encrypted1) + if err != nil { + t.Fatalf("Failed to decrypt first ciphertext: %v", err) + } + if decrypted1 != plaintext { + t.Errorf("First decryption should equal plaintext") + } + + decrypted2, err := crypto.DecryptValue(masterKey, encrypted2) + if err != nil { + t.Fatalf("Failed to decrypt second ciphertext: %v", err) + } + if decrypted2 != plaintext { + t.Errorf("Second decryption should equal plaintext") + } +} + +func TestEncryptDecryptWithWrongKey(t *testing.T) { + key1, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key1: %v", err) + } + key2, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key2: %v", err) + } + + plaintext := "secret data" + + encrypted, err := crypto.EncryptValue(key1, plaintext) + if err != nil { + t.Fatalf("Failed to encrypt: %v", err) + } + + // AES-GCM properly fails with wrong key + _, err = crypto.DecryptValue(key2, encrypted) + if err == nil { + t.Error("DecryptValue should fail with wrong key in GCM mode") + } + + // Should work with correct key + decrypted, err := crypto.DecryptValue(key1, encrypted) + if err != nil { + t.Fatalf("Failed to decrypt with correct key: %v", err) + } + if decrypted != plaintext { + t.Errorf("Expected %q, got %q", plaintext, decrypted) + } +} + +func TestInvalidKeys(t *testing.T) { + plaintext := "test data" + + // Test encryption with invalid key + _, err := crypto.EncryptValue("invalid-key", plaintext) + if err == nil { + t.Error("Expected error for invalid key in encryption") + } + + // Test decryption with invalid key + validKey, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate valid key: %v", err) + } + + encrypted, err := crypto.EncryptValue(validKey, plaintext) + if err != nil { + t.Fatalf("Failed to encrypt: %v", err) + } + + _, err = crypto.DecryptValue("invalid-key", encrypted) + if err == nil { + t.Error("Expected error for invalid key in decryption") + } +} + +func TestInvalidCiphertext(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + // Test ciphertext too short + _, err = crypto.DecryptValue(key, "short") + if err == nil { + t.Error("Expected error for ciphertext too short") + } + + // Test invalid base64 ciphertext + _, err = crypto.DecryptValue(key, "invalid-base64!") + if err == nil { + t.Error("Expected error for invalid base64 ciphertext") + } + + // Test valid base64 but invalid GCM ciphertext + invalidCiphertext := crypto.EncodeValue([]byte("invalid-ciphertext-that-is-long-enough-to-have-nonce")) + _, err = crypto.DecryptValue(key, invalidCiphertext) + if err == nil { + t.Error("Expected error for invalid GCM ciphertext") + } +} + +func TestEncodeDecodeValue(t *testing.T) { + testData := []byte("test data for encoding") + + encoded := crypto.EncodeValue(testData) + if encoded == "" { + t.Error("Encoded value should not be empty") + } + + decoded, err := crypto.DecodeValue(encoded) + if err != nil { + t.Fatalf("Failed to decode value: %v", err) + } + + if string(decoded) != string(testData) { + t.Errorf("Decoded data doesn't match original. Expected %s, got %s", string(testData), string(decoded)) + } + + // Test invalid base64 + _, err = crypto.DecodeValue("invalid-base64!") + if err == nil { + t.Error("Expected error for invalid base64") + } +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/dev.flow b/dev.flow new file mode 100644 index 0000000..7c459df --- /dev/null +++ b/dev.flow @@ -0,0 +1,13 @@ +visibility: private +tags: [development] +executables: + - verb: validate + description: Run the repo's pre-commit script + exec: + dir: // + cmd: | + go fmt ./... + go generate ./... + go mod tidy + golangci-lint run ./... + go test ./... diff --git a/errors.go b/errors.go index dc0e568..21cf2b2 100644 --- a/errors.go +++ b/errors.go @@ -2,8 +2,39 @@ package vault import ( "errors" + "fmt" ) var ( - ErrSecretNotFound = errors.New("secret not found") + ErrSecretNotFound = errors.New("secret not found") + ErrInvalidKey = errors.New("invalid secret key") + ErrNoAccess = errors.New("access denied") + ErrInvalidConfig = errors.New("invalid configuration") + ErrVaultNotFound = errors.New("vault not found") + ErrDecryptionFailed = errors.New("decryption failed") + ErrInvalidRecipient = errors.New("invalid recipient") + ErrPathNotSecure = errors.New("path is not secure") ) + +type VaultPathError struct { + Path string + Err error +} + +func (e *VaultPathError) Error() string { + if e.Path != "" { + return fmt.Sprintf("%s (%s): %v", ErrPathNotSecure, e.Path, e.Err) + } + return fmt.Sprintf("%v: %v", ErrPathNotSecure, e.Err) +} + +func (e *VaultPathError) Unwrap() error { + return e.Err +} + +func NewVaultPathError(path string) *VaultPathError { + return &VaultPathError{ + Path: path, + Err: ErrPathNotSecure, + } +} diff --git a/flow.yaml b/flow.yaml new file mode 100644 index 0000000..c6b1387 --- /dev/null +++ b/flow.yaml @@ -0,0 +1,2 @@ +displayName: vault +descriptionFile: README.md diff --git a/go.mod b/go.mod index 10887c7..d370760 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,13 @@ go 1.24 require ( filippo.io/age v1.2.1 - golang.org/x/crypto v0.24.0 // indirect + golang.org/x/crypto v0.39.0 ) -require golang.org/x/sys v0.21.0 // indirect +require gopkg.in/yaml.v3 v3.0.1 + +require ( + github.com/kr/pretty v0.3.1 // indirect + golang.org/x/sys v0.33.0 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect +) diff --git a/go.sum b/go.sum index 62b31f0..c6fe5ce 100644 --- a/go.sum +++ b/go.sum @@ -2,7 +2,24 @@ c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805 h1:u2qwJeEvnypw+OCPUHmoZE3I c2sp.org/CCTV/age v0.0.0-20240306222714-3ec4d716e805/go.mod h1:FomMrUJ2Lxt5jCLmZkG3FHa72zUprnhd3v/Z18Snm4w= filippo.io/age v1.2.1 h1:X0TZjehAZylOIj4DubWYU1vWQxv9bJpo+Uu2/LGhi1o= filippo.io/age v1.2.1/go.mod h1:JL9ew2lTN+Pyft4RiNGguFfOpewKwSHm5ayKD/A4004= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/local.go b/local.go index ae86438..a9a4d82 100644 --- a/local.go +++ b/local.go @@ -1,21 +1,21 @@ package vault import ( - "bytes" - "encoding/json" "fmt" "os" "path/filepath" - "sync" + "strings" "time" - - "filippo.io/age" ) const ( - currentVaultVersion = 1 - vaultFileBase = "vault" - vaultFileExt = ".age" + vaultFileBase = "vault" + envSource = "env" + fileSource = "file" +) + +var ( + DefaultVaultKeyEnv = "VAULT_KEY" ) type Metadata struct { @@ -23,287 +23,84 @@ type Metadata struct { LastModified time.Time `json:"lastModified"` } -// LocalSecret represents an encrypted secret with its metadata -type LocalSecret struct { - Metadata `json:"metadata"` - - Encrypted []byte `json:"encrypted"` - Nonce []byte `json:"nonce"` - Description string `json:"description,omitempty"` -} - -// LocalState represents the state of the local vault -type LocalState struct { - Metadata `json:"metadata"` - - Version int `json:"version"` - ID string `json:"id"` - Recipients []string `json:"recipients"` - Secrets map[string]string `json:"secrets"` -} - -// LocalVault manages operations on an instance of a local vault -type LocalVault struct { - mu sync.RWMutex - id string - fullPath string - - cfg *LocalConfig - state *LocalState - resolver *IdentityResolver - - identities []age.Identity - recipients []age.Recipient -} - -func NewLocalVault(cfg *Config) (*LocalVault, error) { - path := filepath.Join( - filepath.Clean(cfg.Local.StoragePath), - filepath.Clean(fmt.Sprintf("%s-%s.%s", vaultFileBase, cfg.ID, vaultFileExt)), - ) - - vault := &LocalVault{ - mu: sync.RWMutex{}, - fullPath: path, - id: cfg.ID, - cfg: cfg.Local, - resolver: NewIdentityResolver(cfg.Local.IdentitySources), - } - - ids, err := vault.resolver.ResolveIdentities() - if err != nil { - return nil, fmt.Errorf("failed to resolve identities: %w", err) - } - vault.identities = ids - - if err := vault.load(); err != nil { - return nil, fmt.Errorf("failed to load vault: %w", err) +// validateSecurePath checks if a path is safe to use +func validateSecurePath(path string) error { + if path == "" { + return fmt.Errorf("path cannot be empty") } - if vault.state == nil { - if err := vault.init(); err != nil { - return nil, fmt.Errorf("failed to initialize vault: %w", err) - } - } - - return vault, nil -} - -func (v *LocalVault) init() error { - now := time.Now() - v.state = &LocalState{ - Version: currentVaultVersion, - ID: v.id, - Metadata: Metadata{ - Created: now, - LastModified: now, - }, - Recipients: v.cfg.Recipients, - Secrets: make(map[string]string), - } - - for _, recipientKey := range v.cfg.Recipients { - if err := v.addRecipientToState(recipientKey); err != nil { - return fmt.Errorf("failed to add initial recipient %s: %w", recipientKey, err) - } - } - - if len(v.state.Recipients) == 0 { - // what to do... - return fmt.Errorf("no recipients available for encryption, please add at least one recipient") - } - - if err := v.parseRecipients(); err != nil { - return fmt.Errorf("failed to parse recipients: %w", err) - } - - return v.save() -} - -// load reads the vault file and decrypts its contents -func (v *LocalVault) load() error { - data, err := os.ReadFile(v.fullPath) - if err != nil { - if os.IsNotExist(err) { - return nil - } - return fmt.Errorf("failed to read vault file: %w", err) - } - - if len(data) == 0 { - return nil - } - - // decrypt the vault file using age - r, err := age.Decrypt(bytes.NewReader(data), v.identities...) - if err != nil { - return fmt.Errorf("failed to decrypt vault file - do you have the right key?: %w", err) - } - - var state LocalState - if err := json.NewDecoder(r).Decode(&state); err != nil { - return fmt.Errorf("failed to unmarshal vault state: %w", err) - } - - // store the state and recipients on the LocalVault obj - v.state = &state - if err := v.parseRecipients(); err != nil { - return fmt.Errorf("failed to parse recipients: %w", err) - } - - return nil -} - -// save encrypts and writes the vault contents to disk -func (v *LocalVault) save() error { - if len(v.recipients) == 0 { - return fmt.Errorf("no recipients available for encryption") + // Check for directory traversal attempts + cleanPath := filepath.Clean(path) + if strings.Contains(cleanPath, "..") { + return NewVaultPathError(path) } - v.state.Metadata.LastModified = time.Now() - data, err := json.Marshal(v.state) - if err != nil { - return fmt.Errorf("failed to marshal vault state: %w", err) + // Check for null bytes + if strings.Contains(path, "\x00") { + return NewVaultPathError(path) } - var buf bytes.Buffer - // encrypt the entire file using age - w, err := age.Encrypt(&buf, v.recipients...) + // Ensure the path is absolute after expansion + absPath, err := filepath.Abs(cleanPath) if err != nil { - return fmt.Errorf("failed to create age encryptor: %w", err) - } - if _, err := w.Write(data); err != nil { - return fmt.Errorf("failed to encrypt data: %w", err) - } - if err := w.Close(); err != nil { - return fmt.Errorf("failed to finalize encryption: %w", err) - } - - // write to the file atomically - if err := os.MkdirAll(filepath.Dir(v.fullPath), 0755); err != nil { - return fmt.Errorf("failed to create vault directory: %w", err) + return fmt.Errorf("failed to get absolute path: %w", err) } - tempFile := v.fullPath + ".tmp" - if err := os.WriteFile(tempFile, buf.Bytes(), 0600); err != nil { - return fmt.Errorf("failed to write temp vault file: %w", err) - } - - if err := os.Rename(tempFile, v.fullPath); err != nil { - _ = os.Remove(tempFile) // Clean up on failure - return fmt.Errorf("failed to move vault file: %w", err) - } - - return nil -} - -func (v *LocalVault) ID() string { - return v.id -} - -func (v *LocalVault) GetSecret(key string) (Secret, error) { - v.mu.RLock() - defer v.mu.RUnlock() - - value, exists := v.state.Secrets[key] - if !exists { - return nil, ErrSecretNotFound - } - - return NewSecretValue([]byte(value)), nil -} - -func (v *LocalVault) SetSecret(key string, value Secret) error { - v.mu.Lock() - defer v.mu.Unlock() - - if v.state.Secrets == nil { - v.state.Secrets = make(map[string]string) - } - - v.state.Secrets[key] = value.String() - return v.save() -} -func (v *LocalVault) DeleteSecret(key string) error { - v.mu.Lock() - defer v.mu.Unlock() - - _, exists := v.state.Secrets[key] - if !exists { - return ErrSecretNotFound - } - - delete(v.state.Secrets, key) - return v.save() -} - -func (v *LocalVault) ListSecrets() ([]string, error) { - v.mu.RLock() - defer v.mu.RUnlock() - - keys := make([]string, 0, len(v.state.Secrets)) - for k := range v.state.Secrets { - keys = append(keys, k) + // Basic check that we're not accessing sensitive system directories + systemDirs := []string{"/etc", "/sys", "/proc", "/dev"} + for _, sysDir := range systemDirs { + if strings.HasPrefix(absPath, sysDir) { + return NewVaultPathError(path) + } } - return keys, nil -} - -func (v *LocalVault) HasSecret(key string) (bool, error) { - v.mu.RLock() - defer v.mu.RUnlock() - _, exists := v.state.Secrets[key] - return exists, nil -} - -func (v *LocalVault) Close() error { - // do i need to do anything here? return nil } -func (v *LocalVault) AddRecipient(publicKey string) error { - v.mu.Lock() - defer v.mu.Unlock() - - if err := v.addRecipientToState(publicKey); err != nil { - return err +func expandPath(path string) (string, error) { + if path == "" { + return "", nil } - if err := v.parseRecipients(); err != nil { - return fmt.Errorf("failed to parse recipients: %w", err) - } - - return v.save() -} -func (v *LocalVault) RemoveRecipient(publicKey string) error { - v.mu.Lock() - defer v.mu.Unlock() + var expandedPath string - found := false - for i, rec := range v.state.Recipients { - if rec == publicKey { - v.state.Recipients = append(v.state.Recipients[:i], v.state.Recipients[i+1:]...) - found = true - break + switch path[0] { + case '~': + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %w", err) + } + expandedPath = homeDir + path[1:] + case '/': + expandedPath = path + case '.': + wd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get working directory: %w", err) + } + expandedPath = wd + "/" + path[1:] + case '$': + envVar := path[1:] + if value, exists := os.LookupEnv(envVar); exists { + expandedPath = value + } else { + return "", fmt.Errorf("environment variable %s not found", envVar) + } + default: + wd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("failed to get working directory: %w", err) + } + if wd[len(wd)-1] == '/' { + expandedPath = wd + path + } else { + expandedPath = wd + "/" + path } } - if !found { - return fmt.Errorf("recipient %s not found", publicKey) - } - - if err := v.parseRecipients(); err != nil { - return fmt.Errorf("failed to parse recipients: %w", err) + if err := validateSecurePath(expandedPath); err != nil { + return "", err } - return v.save() -} - -func (v *LocalVault) ListRecipients() ([]string, error) { - v.mu.RLock() - defer v.mu.RUnlock() - - recipients := make([]string, len(v.state.Recipients)) - copy(recipients, v.state.Recipients) // prevent modification of internal state - return recipients, nil + return filepath.Clean(expandedPath), nil } diff --git a/secret.go b/secret.go index 41deef9..3434c1c 100644 --- a/secret.go +++ b/secret.go @@ -1,5 +1,12 @@ package vault +import ( + "crypto/rand" + "fmt" + "regexp" + "runtime" +) + type Secret interface { // PlainTextString returns the decrypted value as a string PlainTextString() string @@ -9,14 +16,46 @@ type Secret interface { // Bytes returns the raw byte representation of the secret Bytes() []byte + + // Zero securely clears the secret from memory + Zero() +} + +// SecureBytes is a wrapper around []byte that provides secure memory handling +type SecureBytes []byte + +// Zero securely clears the byte slice +func (s *SecureBytes) Zero() { + if s != nil && len(*s) > 0 { + // The series of steps below ensures that the memory is cleared securely. It prevents the compiler from + // optimizing away the zeroing operation and is recommended to securely clear sensitive data in Go. + _, _ = rand.Read(*s) + for i := range *s { + (*s)[i] = 0 + } + *s = (*s)[:0] + runtime.GC() + } +} + +// Copy creates a secure copy of the bytes +func (s SecureBytes) Copy() SecureBytes { + if len(s) == 0 { + return SecureBytes{} + } + c := make(SecureBytes, len(s)) + copy(c, s) + return c } type SecretValue struct { - value []byte + value SecureBytes } func NewSecretValue(value []byte) *SecretValue { - return &SecretValue{value: value} + secureValue := make(SecureBytes, len(value)) + copy(secureValue, value) + return &SecretValue{value: secureValue} } func (s *SecretValue) PlainTextString() string { @@ -28,5 +67,23 @@ func (s *SecretValue) String() string { } func (s *SecretValue) Bytes() []byte { - return s.value + // Return a copy to prevent external modification + result := make([]byte, len(s.value)) + copy(result, s.value) + return result +} + +func (s *SecretValue) Zero() { + s.value.Zero() +} + +func ValidateSecretKey(reference string) error { + if reference == "" { + return ErrInvalidKey + } + re := regexp.MustCompile(`^[a-zA-Z0-9-_.]+$`) + if !re.MatchString(reference) { + return fmt.Errorf("%w: must only contain alphanumeric characters, dashes, underscores, and/or dots", ErrInvalidKey) + } + return nil } diff --git a/utils.go b/utils.go deleted file mode 100644 index 6a28bb8..0000000 --- a/utils.go +++ /dev/null @@ -1,33 +0,0 @@ -package vault - -import "os" - -func expandPath(path string) string { - if path == "" { - return "" - } - - switch path[0] { - case '~': - homeDir, _ := os.UserHomeDir() - return homeDir + path[1:] - case '/': - return path - case '.': - wd, _ := os.Getwd() - return wd + "/" + path[1:] - case '$': - envVar := path[1:] - if value, exists := os.LookupEnv(envVar); exists { - return value - } - default: - wd, _ := os.Getwd() - if wd[len(wd)-1] == '/' { - return wd + path - } else { - return wd + "/" + path - } - } - return path -} diff --git a/vault.go b/vault.go index d259062..8087a07 100644 --- a/vault.go +++ b/vault.go @@ -13,6 +13,10 @@ type Provider interface { // ID returns a unique identifier for this vault instance ID() string + + // Metadata returns vault metadata such as creation time + Metadata() Metadata + Close() error } @@ -29,8 +33,10 @@ func New(id string, opts ...Option) (Provider, error) { } switch config.Type { - case ProviderTypeLocal: - return NewLocalVault(config) + case ProviderTypeAge: + return NewAgeVault(config) + case ProviderTypeAES256: + return NewAES256Vault(config) case ProviderTypeExternal: return nil, fmt.Errorf("external vault provider not implemented yet") } @@ -44,63 +50,121 @@ func WithProvider(provider ProviderType) Option { } } -// WithLocalPath sets the local vault storage fullPath +// WithAgePath sets the age vault storage path +func WithAgePath(path string) Option { + return func(c *Config) { + if c.Age == nil { + c.Age = &AgeConfig{} + } + c.Age.StoragePath = path + } +} + +// WithAESPath sets the AES vault storage path +func WithAESPath(path string) Option { + return func(c *Config) { + if c.Aes == nil { + c.Aes = &AesConfig{} + } + c.Aes.StoragePath = path + } +} + +// WithLocalPath sets the local vault storage path (works for both Age and AES based on provider type) func WithLocalPath(path string) Option { return func(c *Config) { - if c.Local == nil { - c.Local = &LocalConfig{} + //nolint:exhaustive + switch c.Type { + case ProviderTypeAge: + if c.Age == nil { + c.Age = &AgeConfig{} + } + c.Age.StoragePath = path + case ProviderTypeAES256: + if c.Aes == nil { + c.Aes = &AesConfig{} + } + c.Aes.StoragePath = path } - c.Local.StoragePath = path } } -// WithLocalIdentityFromEnv specifies to retrieve the key from an environment variable for local vaults -func WithLocalIdentityFromEnv(envVar string) Option { +// WithAgeIdentityFromEnv specifies to retrieve the age identity from an environment variable +func WithAgeIdentityFromEnv(envVar string) Option { return func(c *Config) { - if c.Local == nil { - c.Local = &LocalConfig{} + if c.Age == nil { + c.Age = &AgeConfig{} } - if len(c.Local.IdentitySources) == 0 { - c.Local.IdentitySources = make([]IdentitySource, 0) + if len(c.Age.IdentitySources) == 0 { + c.Age.IdentitySources = make([]IdentitySource, 0) } - c.Local.IdentitySources = append( - c.Local.IdentitySources, + c.Age.IdentitySources = append( + c.Age.IdentitySources, IdentitySource{Type: "env", Name: envVar}, ) } } -// WithLocalIdentityFromFile specifies to retrieve the key from a file for local vaults -func WithLocalIdentityFromFile(path string) Option { +// WithAgeIdentityFromFile specifies to retrieve the age identity from a file +func WithAgeIdentityFromFile(path string) Option { return func(c *Config) { - if c.Local == nil { - c.Local = &LocalConfig{} + if c.Age == nil { + c.Age = &AgeConfig{} } - if len(c.Local.IdentitySources) == 0 { - c.Local.IdentitySources = make([]IdentitySource, 0) + if len(c.Age.IdentitySources) == 0 { + c.Age.IdentitySources = make([]IdentitySource, 0) } - c.Local.IdentitySources = append( - c.Local.IdentitySources, + c.Age.IdentitySources = append( + c.Age.IdentitySources, IdentitySource{Type: "file", Path: path}, ) } } -// WithRecipients sets the recipients for local vaults -func WithRecipients(recipients ...string) Option { +// WithAESKeyFromEnv specifies to retrieve the AES key from an environment variable +func WithAESKeyFromEnv(envVar string) Option { + return func(c *Config) { + if c.Aes == nil { + c.Aes = &AesConfig{} + } + if len(c.Aes.KeySource) == 0 { + c.Aes.KeySource = make([]KeySource, 0) + } + c.Aes.KeySource = append( + c.Aes.KeySource, + KeySource{Type: "env", Name: envVar}, + ) + } +} + +// WithAESKeyFromFile specifies to retrieve the AES key from a file +func WithAESKeyFromFile(path string) Option { + return func(c *Config) { + if c.Aes == nil { + c.Aes = &AesConfig{} + } + if len(c.Aes.KeySource) == 0 { + c.Aes.KeySource = make([]KeySource, 0) + } + c.Aes.KeySource = append( + c.Aes.KeySource, + KeySource{Type: "file", Path: path}, + ) + } +} + +// WithAgeRecipients sets the recipients for age vaults +func WithAgeRecipients(recipients ...string) Option { return func(c *Config) { - if c.Local == nil { - c.Local = &LocalConfig{} + if c.Age == nil { + c.Age = &AgeConfig{} } - // if len(c.Local.Recipients) == 0 { - // c.Local.Recipients = make([]string, len(recipients)) - // } - c.Local.Recipients = append(c.Local.Recipients, recipients...) + c.Age.Recipients = append(c.Age.Recipients, recipients...) } } // WithExternalConfig sets the external vault configuration. FOR TESTING PURPOSES ONLY. -// TODO: break this down when the external provider is fully impelemented +// TODO: break this down when the external provider is fully implemented func WithExternalConfig(cfg *ExternalConfig) Option { return func(c *Config) { c.Type = ProviderTypeExternal diff --git a/vault_test.go b/vault_test.go new file mode 100644 index 0000000..f8f9511 --- /dev/null +++ b/vault_test.go @@ -0,0 +1,384 @@ +package vault_test + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/jahvon/vault" +) + +func TestVaultInterface(t *testing.T) { + tests := []struct { + name string + provider vault.ProviderType + setup func(t *testing.T, dir string) vault.Provider + }{ + { + name: "AES256 Vault", + provider: vault.ProviderTypeAES256, + setup: setupAESVault, + }, + { + name: "Age Vault", + provider: vault.ProviderTypeAge, + setup: setupAgeVault, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + + v := tt.setup(t, tempDir) + defer v.Close() + + testBasicOperations(t, v) + testSecretOperations(t, v) + testPersistence(t, v, tt.provider, tempDir) + }) + } +} + +func setupAESVault(t *testing.T, dir string) vault.Provider { + // Only generate a new key if one isn't already set + if os.Getenv(vault.DefaultVaultKeyEnv) == "" { + key, err := vault.GenerateEncryptionKey() + if err != nil { + t.Fatalf("Failed to generate test key: %v", err) + } + t.Setenv(vault.DefaultVaultKeyEnv, key) + } + + v, err := vault.New("test-aes", + vault.WithProvider(vault.ProviderTypeAES256), + vault.WithAESPath(dir), + vault.WithAESKeyFromEnv(vault.DefaultVaultKeyEnv), + ) + if err != nil { + t.Fatalf("Failed to create AES vault: %v", err) + } + + return v +} + +func setupAgeVault(t *testing.T, dir string) vault.Provider { + testIdentity := "AGE-SECRET-KEY-1LC563A3EG4TLDL5EQE0YP5ZSJW8NADURXLZ8WVM00DMKG60URRNQ5TRZH0" + testRecipient := "age1wnhg53pg2qfsfxwvxvlg6pygw5uzwcyhj2dqhg0k83fvjexf9pzsxqdvs0" + + keyFile := filepath.Join(dir, "test-key.txt") + err := os.WriteFile(keyFile, []byte(testIdentity), 0600) + if err != nil { + t.Fatalf("Failed to write test key file: %v", err) + } + + v, err := vault.New("test-age", + vault.WithProvider(vault.ProviderTypeAge), + vault.WithAgePath(dir), + vault.WithAgeIdentityFromFile(keyFile), + vault.WithAgeRecipients(testRecipient), + ) + if err != nil { + t.Fatalf("Failed to create Age vault: %v", err) + } + + return v +} + +func testBasicOperations(t *testing.T, v vault.Provider) { + id := v.ID() + if id == "" { + t.Error("Vault ID should not be empty") + } + + secrets, err := v.ListSecrets() + if err != nil { + t.Fatalf("Failed to list secrets: %v", err) + } + if len(secrets) != 0 { + t.Errorf("New vault should have 0 secrets, got %d", len(secrets)) + } + + exists, err := v.HasSecret("nonexistent") + if err != nil { + t.Fatalf("Failed to check secret existence: %v", err) + } + if exists { + t.Error("HasSecret should return false for nonexistent secret") + } + + _, err = v.GetSecret("nonexistent") + if !errors.Is(err, vault.ErrSecretNotFound) { + t.Errorf("Expected ErrSecretNotFound, got: %v", err) + } +} + +func testSecretOperations(t *testing.T, v vault.Provider) { + testCases := []struct { + key string + value string + }{ + {"api-key", "my-secret-api-key"}, + {"db-password", "super-secret-password"}, + {"special-chars", "!@#$%^&*()_+-={}[]|\\:;\"'<>?,./ ~`"}, + {"unicode", "🔐 secret with emoji 🚀"}, + {"empty", ""}, + } + + for _, tc := range testCases { + secret := vault.NewSecretValue([]byte(tc.value)) + err := v.SetSecret(tc.key, secret) + if err != nil { + t.Fatalf("Failed to set secret %s: %v", tc.key, err) + } + } + + // Verify all secrets can be retrieved + for _, tc := range testCases { + secret, err := v.GetSecret(tc.key) + if err != nil { + t.Fatalf("Failed to get secret %s: %v", tc.key, err) + } + if secret.PlainTextString() != tc.value { + t.Errorf("Secret %s: expected %q, got %q", tc.key, tc.value, secret.PlainTextString()) + } + + // Test HasSecret + exists, err := v.HasSecret(tc.key) + if err != nil { + t.Fatalf("Failed to check secret %s existence: %v", tc.key, err) + } + if !exists { + t.Errorf("HasSecret should return true for %s", tc.key) + } + } + + // Test ListSecrets + secrets, err := v.ListSecrets() + if err != nil { + t.Fatalf("Failed to list secrets: %v", err) + } + if len(secrets) != len(testCases) { + t.Errorf("Expected %d secrets, got %d", len(testCases), len(secrets)) + } + + // Verify all expected keys are present + keyMap := make(map[string]bool) + for _, key := range secrets { + keyMap[key] = true + } + for _, tc := range testCases { + if !keyMap[tc.key] { + t.Errorf("Missing secret key: %s", tc.key) + } + } + + // Test updating existing secret + newSecret := vault.NewSecretValue([]byte("updated-value")) + err = v.SetSecret("api-key", newSecret) + if err != nil { + t.Fatalf("Failed to update secret: %v", err) + } + + retrievedSecret, err := v.GetSecret("api-key") + if err != nil { + t.Fatalf("Failed to get updated secret: %v", err) + } + if retrievedSecret.PlainTextString() != "updated-value" { + t.Errorf("Updated secret: expected 'updated-value', got %q", retrievedSecret.PlainTextString()) + } + + // Test DeleteSecret + err = v.DeleteSecret("api-key") + if err != nil { + t.Fatalf("Failed to delete secret: %v", err) + } + + _, err = v.GetSecret("api-key") + if !errors.Is(err, vault.ErrSecretNotFound) { + t.Errorf("Expected ErrSecretNotFound after deletion, got: %v", err) + } + + exists, err := v.HasSecret("api-key") + if err != nil { + t.Fatalf("Failed to check deleted secret existence: %v", err) + } + if exists { + t.Error("HasSecret should return false for deleted secret") + } + + // Verify list count decreased + secrets, err = v.ListSecrets() + if err != nil { + t.Fatalf("Failed to list secrets after deletion: %v", err) + } + if len(secrets) != len(testCases)-1 { + t.Errorf("Expected %d secrets after deletion, got %d", len(testCases)-1, len(secrets)) + } + + // Test deleting nonexistent secret + err = v.DeleteSecret("nonexistent") + if !errors.Is(err, vault.ErrSecretNotFound) { + t.Errorf("Expected ErrSecretNotFound when deleting nonexistent secret, got: %v", err) + } +} + +func testPersistence(t *testing.T, v vault.Provider, provider vault.ProviderType, dir string) { + // Store a test secret + testSecret := vault.NewSecretValue([]byte("persistence-test")) + err := v.SetSecret("persist-test", testSecret) + if err != nil { + t.Fatalf("Failed to set persistence test secret: %v", err) + } + + // Close the vault + err = v.Close() + if err != nil { + t.Fatalf("Failed to close vault: %v", err) + } + + // Verify encrypted file exists + var pattern string + switch provider { + case vault.ProviderTypeAES256: + pattern = "*.enc" + case vault.ProviderTypeAge: + pattern = "*.age" + } + + files, err := filepath.Glob(filepath.Join(dir, pattern)) + if err != nil { + t.Fatalf("Failed to find vault files: %v", err) + } + if len(files) == 0 { + t.Fatalf("No encrypted vault file found (pattern: %s)", pattern) + } + + // Recreate vault and verify secret persisted + var newVault vault.Provider + switch provider { + case vault.ProviderTypeAES256: + newVault = setupAESVault(t, dir) + case vault.ProviderTypeAge: + newVault = setupAgeVault(t, dir) + } + defer newVault.Close() + + retrievedSecret, err := newVault.GetSecret("persist-test") + if err != nil { + t.Fatalf("Failed to get persisted secret: %v", err) + } + if retrievedSecret.PlainTextString() != "persistence-test" { + t.Errorf("Persisted secret: expected 'persistence-test', got %q", retrievedSecret.PlainTextString()) + } + + // Test that Metadata is preserved + metadata := newVault.Metadata() + if metadata.Created.IsZero() { + t.Error("Metadata creation time should not be zero") + } + if metadata.LastModified.IsZero() { + t.Error("Metadata last modified time should not be zero") + } + if metadata.Created.After(metadata.LastModified) { + t.Error("Metadata creation time should not be after last modified time") + } + if metadata.LastModified.Before(metadata.Created) { + t.Error("Metadata last modified time should not be before creation time") + } +} + +func TestSecretValidation(t *testing.T) { + tempDir := t.TempDir() + v := setupAESVault(t, tempDir) + defer v.Close() + + // Test invalid secret keys + invalidKeys := []string{ + "", // empty + "key with spaces", + "key/with/slashes", + "key\\with\\backslashes", + "key\nwith\nnewlines", + "key\twith\ttabs", + } + + for _, key := range invalidKeys { + secret := vault.NewSecretValue([]byte("test")) + err := v.SetSecret(key, secret) + if err == nil { + t.Errorf("Expected error for invalid key %q, but got none", key) + } + } + + // Test valid keys + validKeys := []string{ + "simple-key", + "key_with_underscores", + "key-with-dashes", + "key123", + "UPPERCASE", + "mixedCase", + "key.with.dots", + } + + for _, key := range validKeys { + secret := vault.NewSecretValue([]byte("test")) + err := v.SetSecret(key, secret) + if err != nil { + t.Errorf("Expected no error for valid key %q, but got: %v", key, err) + } + } +} + +func TestConcurrentAccess(t *testing.T) { + tempDir := t.TempDir() + v := setupAESVault(t, tempDir) + defer v.Close() + + // Test concurrent reads and writes + done := make(chan bool) + errs := make(chan error, 10) + + // Concurrent writers + for i := 0; i < 5; i++ { + go func(id int) { + for j := 0; j < 10; j++ { + key := fmt.Sprintf("key-%d-%d", id, j) + value := fmt.Sprintf("value-%d-%d", id, j) + secret := vault.NewSecretValue([]byte(value)) + if err := v.SetSecret(key, secret); err != nil { + errs <- err + return + } + } + done <- true + }(i) + } + + // Concurrent readers + for i := 0; i < 3; i++ { + go func() { + for j := 0; j < 20; j++ { + _, _ = v.ListSecrets() // Don't care about errors here since secrets are being added concurrently + time.Sleep(time.Millisecond) + } + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < 8; i++ { + select { + case err := <-errs: + t.Fatalf("Concurrent operation failed: %v", err) + case <-done: + // Success + case <-time.After(10 * time.Second): + t.Fatal("Concurrent test timed out") + } + } +}