diff --git a/pkg/rdsdiff/compare.go b/pkg/rdsdiff/compare.go new file mode 100644 index 00000000..d8018061 --- /dev/null +++ b/pkg/rdsdiff/compare.go @@ -0,0 +1,502 @@ +// SPDX-License-Identifier:Apache-2.0 + +package rdsdiff + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "github.com/sergi/go-diff/diffmatchpatch" + sigsyaml "sigs.k8s.io/yaml" +) + +var slashRe = regexp.MustCompile(`/+`) + +// CRKeyFromObject builds a stable key for a CR: kind_namespace_name, or kind_name for cluster-scoped. +func CRKeyFromObject(obj map[string]any) string { + kind, _ := obj["kind"].(string) + if kind == "" { + kind = "Unknown" + } + meta, _ := obj["metadata"].(map[string]any) + if meta == nil { + meta = make(map[string]any) + } + name, _ := meta["name"].(string) + if name == "" { + name = "unnamed" + } + ns, _ := meta["namespace"].(string) + var key string + if ns == "" { + key = kind + "_" + name + } else { + key = kind + "_" + ns + "_" + name + } + return slashRe.ReplaceAllString(key, "-") +} + +// CRPair is a key and object pair from a Policy document. +type CRPair struct { + Key string + Obj map[string]any +} + +// ExtractCRsFromPolicyDoc returns CRPair for each CR embedded in a Policy document. 
+func ExtractCRsFromPolicyDoc(doc map[string]any) []CRPair { + if kind, _ := doc["kind"].(string); kind != "Policy" { + return nil + } + var out []CRPair + spec, _ := doc["spec"].(map[string]any) + if spec == nil { + return nil + } + templates, _ := spec["policy-templates"].([]any) + for _, t := range templates { + tm, _ := t.(map[string]any) + if tm == nil { + continue + } + od, _ := tm["objectDefinition"].(map[string]any) + if od == nil { + continue + } + innerSpec, _ := od["spec"].(map[string]any) + if innerSpec == nil { + continue + } + objTemplates, _ := innerSpec["object-templates"].([]any) + for _, ot := range objTemplates { + otm, _ := ot.(map[string]any) + if otm == nil { + continue + } + obj, _ := otm["objectDefinition"].(map[string]any) + if obj == nil { + continue + } + key := CRKeyFromObject(obj) + out = append(out, CRPair{Key: key, Obj: obj}) + } + } + return out +} + +// ExtractCRs reads all *.yaml in generatedDir, collects CRs from Policy files, writes extractedDir with one file per key. 
+func ExtractCRs(generatedDir, extractedDir string) error { + generatedDir = filepath.Clean(generatedDir) + extractedDir = filepath.Clean(extractedDir) + if err := os.MkdirAll(extractedDir, 0o750); err != nil { + return fmt.Errorf("mkdir extracted dir: %w", err) + } + entries, err := readCleanDir(generatedDir) + if err != nil { + return fmt.Errorf("read generated dir: %w", err) + } + collected := make(map[string]map[string]any) + for _, e := range entries { + if e.IsDir() || !strings.HasSuffix(strings.ToLower(e.Name()), ".yaml") { + continue + } + path, err := safeJoinPath(generatedDir, e.Name()) + if err != nil { + return fmt.Errorf("unsafe entry in generated dir: %w", err) + } + data, err := readCleanFile(path) + if err != nil { + return fmt.Errorf("read %s: %w", path, err) + } + docs := unmarshalMultiDocYAML(data) + for _, doc := range docs { + if doc == nil { + continue + } + for _, pair := range ExtractCRsFromPolicyDoc(doc) { + collected[pair.Key] = pair.Obj + } + } + } + keys := make([]string, 0, len(collected)) + for k := range collected { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + norm, err := normalizeYAML(collected[key]) + if err != nil { + return err + } + outPath, err := safeJoinPath(extractedDir, key+".yaml") + if err != nil { + return fmt.Errorf("unsafe CR key %q: %w", key, err) + } + if err := writeCleanFile(outPath, []byte(norm)); err != nil { + return fmt.Errorf("write %s: %w", outPath, err) + } + } + return nil +} + +func unmarshalMultiDocYAML(data []byte) []map[string]any { + var docs []map[string]any + content := string(data) + // Strip leading document marker so splitting on "\n---" works for all cases + content = strings.TrimPrefix(content, "---\n") + content = strings.TrimPrefix(content, "---\r\n") + parts := strings.Split(content, "\n---") + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + var doc map[string]any + if err := sigsyaml.Unmarshal([]byte(part), 
&doc); err != nil { + continue + } + if doc != nil { + docs = append(docs, doc) + } + } + if len(docs) == 0 { + var single map[string]any + if err := sigsyaml.Unmarshal(data, &single); err == nil && single != nil { + docs = []map[string]any{single} + } + } + return docs +} + +func getKeysFromExtractedDir(dir string) ([]string, error) { + entries, err := readCleanDir(dir) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, nil + } + return nil, fmt.Errorf("read dir %s: %w", dir, err) + } + var keys []string + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if strings.HasSuffix(strings.ToLower(name), ".yaml") { + keys = append(keys, strings.TrimSuffix(name, filepath.Ext(name))) + } + } + sort.Strings(keys) + return keys, nil +} + +// GetKeysFromExtractedDir returns sorted list of CR keys (filename stem without .yaml). +func GetKeysFromExtractedDir(dir string) ([]string, error) { + return getKeysFromExtractedDir(dir) +} + +// Correlate returns onlyOld, onlyNew, inBoth key lists (all sorted). +func Correlate(oldDir, newDir string) (onlyOld, onlyNew, inBoth []string, err error) { + return correlate(oldDir, newDir) +} + +// NormalizeYAML marshals obj with sorted keys for stable diff. +func NormalizeYAML(obj map[string]any) (string, error) { + return normalizeYAML(obj) +} + +// ComputeDiff loads both YAMLs, normalizes them, returns unified diff text. +func ComputeDiff(oldPath, newPath string) (string, error) { + return computeDiff(oldPath, newPath) +} + +// BuildSummary builds the summary block text. 
func BuildSummary(onlyOld, onlyNew, inBoth []string, numDiffer int) string {
	return buildSummary(onlyOld, onlyNew, inBoth, numDiffer)
}

// correlate partitions the CR keys found in oldDir and newDir into three
// sorted lists: keys present only in old, only in new, and in both.
func correlate(oldDir, newDir string) (onlyOld, onlyNew, inBoth []string, err error) {
	oldKeys, err := getKeysFromExtractedDir(oldDir)
	if err != nil {
		return nil, nil, nil, err
	}
	newKeys, err := getKeysFromExtractedDir(newDir)
	if err != nil {
		return nil, nil, nil, err
	}
	// Build membership sets for O(1) lookups in the passes below.
	oldSet := make(map[string]struct{})
	for _, k := range oldKeys {
		oldSet[k] = struct{}{}
	}
	newSet := make(map[string]struct{})
	for _, k := range newKeys {
		newSet[k] = struct{}{}
	}
	for _, k := range oldKeys {
		if _, in := newSet[k]; !in {
			onlyOld = append(onlyOld, k)
		} else {
			inBoth = append(inBoth, k)
		}
	}
	for _, k := range newKeys {
		if _, in := oldSet[k]; !in {
			onlyNew = append(onlyNew, k)
		}
	}
	sort.Strings(onlyOld)
	sort.Strings(onlyNew)
	sort.Strings(inBoth)
	return onlyOld, onlyNew, inBoth, nil
}

// sortMapKeys returns a copy of m whose nested maps/slices have been rebuilt
// in sorted-key order. NOTE(review): Go maps are unordered, so the ordering
// effect materializes only when the result is marshaled (sigs.k8s.io/yaml
// marshals via encoding/json, which sorts map keys).
func sortMapKeys(m map[string]any) map[string]any {
	out := make(map[string]any, len(m))
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, k := range keys {
		v := m[k]
		switch tv := v.(type) {
		case map[string]any:
			out[k] = sortMapKeys(tv)
		case []any:
			out[k] = sortSlice(tv)
		default:
			out[k] = v
		}
	}
	return out
}

// sortSlice copies s, recursively normalizing nested maps and slices.
// Element order is preserved (slice order is semantically meaningful).
func sortSlice(s []any) []any {
	out := make([]any, len(s))
	for i, v := range s {
		switch tv := v.(type) {
		case map[string]any:
			out[i] = sortMapKeys(tv)
		case []any:
			out[i] = sortSlice(tv)
		default:
			out[i] = v
		}
	}
	return out
}

// normalizeYAML renders obj as canonical YAML for stable text comparison.
func normalizeYAML(obj map[string]any) (string, error) {
	sorted := sortMapKeys(obj)
	b, err := sigsyaml.Marshal(sorted)
	if err != nil {
		return "", fmt.Errorf("marshal yaml: %w", err)
	}
	return string(b), nil
}

// computeDiff reads two YAML files, normalizes both, and returns the unified
// diff of the normalized texts ("" when identical). Unparseable YAML is
// deliberately treated as an empty object rather than an error so a corrupt
// file shows up as a full diff instead of aborting the run.
func computeDiff(oldPath, newPath string) (string, error) {
	oldData, err := readCleanFile(oldPath)
	if err != nil {
		return "", fmt.Errorf("read old %s: %w", oldPath, err)
	}
	newData, err := readCleanFile(newPath)
	if err != nil {
		return "", fmt.Errorf("read new %s: %w", newPath, err)
	}
	var oldObj, newObj map[string]any
	if err := sigsyaml.Unmarshal(oldData, &oldObj); err != nil {
		oldObj = nil
	}
	if err := sigsyaml.Unmarshal(newData, &newObj); err != nil {
		newObj = nil
	}
	if oldObj == nil {
		oldObj = make(map[string]any)
	}
	if newObj == nil {
		newObj = make(map[string]any)
	}
	oldStr, err := normalizeYAML(oldObj)
	if err != nil {
		return "", fmt.Errorf("normalize old: %w", err)
	}
	newStr, err := normalizeYAML(newObj)
	if err != nil {
		return "", fmt.Errorf("normalize new: %w", err)
	}
	return unifiedDiff(oldStr, newStr, filepath.Base(oldPath), filepath.Base(newPath)), nil
}

// unifiedDiff computes a line-level unified diff between two strings using diffmatchpatch.
// Returns "" when the inputs are identical.
func unifiedDiff(oldStr, newStr, fromFile, toFile string) string {
	dmp := diffmatchpatch.New()
	// Line-mode diff: map lines to chars, diff the char streams, map back.
	a, b, c := dmp.DiffLinesToChars(oldStr, newStr)
	diffs := dmp.DiffMain(a, b, false)
	diffs = dmp.DiffCharsToLines(diffs, c)
	diffs = dmp.DiffCleanupSemantic(diffs)

	allEqual := true
	for _, d := range diffs {
		if d.Type != diffmatchpatch.DiffEqual {
			allEqual = false
			break
		}
	}
	if allEqual {
		return ""
	}

	var buf strings.Builder
	buf.WriteString("--- " + fromFile + "\n")
	buf.WriteString("+++ " + toFile + "\n")
	for _, d := range diffs {
		lines := strings.Split(strings.TrimRight(d.Text, "\n"), "\n")
		for _, line := range lines {
			switch d.Type {
			case diffmatchpatch.DiffEqual:
				buf.WriteString(" " + line + "\n")
			case diffmatchpatch.DiffDelete:
				buf.WriteString("-" + line + "\n")
			case diffmatchpatch.DiffInsert:
				buf.WriteString("+" + line + "\n")
			}
		}
	}
	return buf.String()
}

// buildSummary formats the four correlation/diff counters as a fixed block.
func buildSummary(onlyOld, onlyNew, inBoth []string, numDiffer int) string {
	return fmt.Sprintf("Comparison summary\n==================\nOnly in old: %d\nOnly in new: %d\nIn both: %d\nDiffer: %d\n\n",
		len(onlyOld), len(onlyNew), len(inBoth), numDiffer)
}

// ComparisonJSON is the shape of the comparison JSON file.
type ComparisonJSON struct {
	OnlyOld      []string               `json:"only_old"`
	OnlyNew      []string               `json:"only_new"`
	InBoth       []string               `json:"in_both"`
	Diffs        map[string]CRDiffEntry `json:"diffs"`
	Summary      string                 `json:"summary"`
	OldExtracted string                 `json:"old_extracted"`
	NewExtracted string                 `json:"new_extracted"`
}

// CRDiffEntry holds old/new content and diff text for one CR key.
type CRDiffEntry struct {
	OldContent string `json:"old_content"`
	NewContent string `json:"new_content"`
	DiffText   string `json:"diff_text"`
}

// RunCompare extracts CRs from generated policy directories, compares them, and writes a report and
// comparison JSON under sessionDir. Returns the summary text.
func RunCompare(oldGenerated, newGenerated, sessionDir string) (string, error) {
	oldGenerated = filepath.Clean(oldGenerated)
	newGenerated = filepath.Clean(newGenerated)
	sessionDir = filepath.Clean(sessionDir)
	oldExtracted := filepath.Join(sessionDir, "old-extracted")
	newExtracted := filepath.Join(sessionDir, "new-extracted")
	if err := os.MkdirAll(oldExtracted, 0o750); err != nil {
		return "", fmt.Errorf("mkdir old-extracted: %w", err)
	}
	if err := os.MkdirAll(newExtracted, 0o750); err != nil {
		return "", fmt.Errorf("mkdir new-extracted: %w", err)
	}
	if err := ExtractCRs(oldGenerated, oldExtracted); err != nil {
		return "", fmt.Errorf("extract old: %w", err)
	}
	if err := ExtractCRs(newGenerated, newExtracted); err != nil {
		return "", fmt.Errorf("extract new: %w", err)
	}
	reportPath := filepath.Join(sessionDir, "diff-report.txt")
	comparisonJSONPath := filepath.Join(sessionDir, "comparison.json")
	return Run(oldExtracted, newExtracted, reportPath, comparisonJSONPath)
}

// Run performs correlate, diff, writes report and comparison JSON; returns summary.
func Run(oldExtracted, newExtracted, reportPath, comparisonJSONPath string) (string, error) {
	oldExtracted = filepath.Clean(oldExtracted)
	newExtracted = filepath.Clean(newExtracted)
	reportPath = filepath.Clean(reportPath)
	comparisonJSONPath = filepath.Clean(comparisonJSONPath)

	onlyOld, onlyNew, inBoth, err := correlate(oldExtracted, newExtracted)
	if err != nil {
		return "", fmt.Errorf("correlate: %w", err)
	}

	// For every key present on both sides, record content and diff text;
	// count the keys whose diff is non-empty.
	diffs := make(map[string]CRDiffEntry)
	numDiffer := 0
	for _, key := range inBoth {
		oldPath, err := safeJoinPath(oldExtracted, key+".yaml")
		if err != nil {
			return "", fmt.Errorf("unsafe key %q for old: %w", key, err)
		}
		newPath, err := safeJoinPath(newExtracted, key+".yaml")
		if err != nil {
			return "", fmt.Errorf("unsafe key %q for new: %w", key, err)
		}
		diffText, err := computeDiff(oldPath, newPath)
		if err != nil {
			return "", fmt.Errorf("diff %s: %w", key, err)
		}
		oldContent, err := readCleanFile(oldPath)
		if err != nil {
			return "", fmt.Errorf("read old content %s: %w", key, err)
		}
		newContent, err := readCleanFile(newPath)
		if err != nil {
			return "", fmt.Errorf("read new content %s: %w", key, err)
		}
		diffs[key] = CRDiffEntry{
			OldContent: string(oldContent),
			NewContent: string(newContent),
			DiffText:   diffText,
		}
		if strings.TrimSpace(diffText) != "" {
			numDiffer++
		}
	}

	summary := buildSummary(onlyOld, onlyNew, inBoth, numDiffer)

	// Text report: summary block followed by one section per differing key.
	reportLines := []string{summary, ""}
	for _, key := range inBoth {
		if strings.TrimSpace(diffs[key].DiffText) != "" {
			reportLines = append(reportLines, "--- "+key+" ---", diffs[key].DiffText, "")
		}
	}
	if err := os.MkdirAll(filepath.Dir(reportPath), 0o750); err != nil {
		return "", fmt.Errorf("mkdir report dir: %w", err)
	}
	if err := writeCleanFile(reportPath, []byte(strings.Join(reportLines, "\n"))); err != nil {
		return "", fmt.Errorf("write report: %w", err)
	}

	comparison := ComparisonJSON{
		OnlyOld:      onlyOld,
		OnlyNew:      onlyNew,
		InBoth:       inBoth,
		Diffs:        diffs,
		Summary:      summary,
		OldExtracted: oldExtracted,
		NewExtracted: newExtracted,
	}
	comparisonBytes, err := json.MarshalIndent(comparison, "", "  ")
	if err != nil {
		return "", fmt.Errorf("marshal comparison: %w", err)
	}
	if err := os.MkdirAll(filepath.Dir(comparisonJSONPath), 0o750); err != nil {
		return "", fmt.Errorf("mkdir comparison dir: %w", err)
	}
	if err := writeCleanFile(comparisonJSONPath, comparisonBytes); err != nil {
		return "", fmt.Errorf("write comparison: %w", err)
	}
	return summary, nil
}
diff --git a/pkg/rdsdiff/compare_test.go b/pkg/rdsdiff/compare_test.go new file mode 100644 index 00000000..199fa5e3 --- /dev/null +++ b/pkg/rdsdiff/compare_test.go @@ -0,0 +1,435 @@
// SPDX-License-Identifier:Apache-2.0

package rdsdiff

import (
	"encoding/json"
	"os"
	"path/filepath"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	sigsyaml "sigs.k8s.io/yaml"
)

// minimalPolicyYAML is a Policy wrapping two CRs: a namespaced ConfigMap and
// a cluster-scoped ImageContentSourcePolicy.
// NOTE(review): indentation reconstructed from YAML structure — confirm
// against the original file.
const minimalPolicyYAML = `
apiVersion: policy.open-cluster-management.io/v1
kind: Policy
metadata:
  name: test-policy
  namespace: ztp-common
spec:
  disabled: false
  policy-templates:
  - objectDefinition:
      apiVersion: policy.open-cluster-management.io/v1
      kind: ConfigurationPolicy
      metadata:
        name: test-config-policy
      spec:
        object-templates:
        - complianceType: musthave
          objectDefinition:
            apiVersion: v1
            kind: ConfigMap
            metadata:
              name: my-config
              namespace: openshift-monitoring
            data:
              key: value
        - complianceType: musthave
          objectDefinition:
            apiVersion: operator.openshift.io/v1alpha1
            kind: ImageContentSourcePolicy
            metadata:
              name: my-icsp
            spec:
              repositoryDigestMirrors: []
`

// policyWithDuplicateKey embeds two CRs that map to the same key, to verify
// last-definition-wins behavior during extraction.
const policyWithDuplicateKey = `
apiVersion: policy.open-cluster-management.io/v1
kind: Policy
metadata:
  name: dup-policy
  namespace: ztp-common
spec:
  policy-templates:
  - objectDefinition:
      spec:
        object-templates:
        - objectDefinition:
            kind: ConfigMap
            metadata:
              name: same-name
              namespace: ns1
            data:
              first: a
        - objectDefinition:
            kind: ConfigMap
            metadata:
              name: same-name
              namespace: ns1
            data:
              second: b
`

func TestGetKeysFromExtractedDir_Empty(t *testing.T) {
	dir := t.TempDir()
	keys, err := GetKeysFromExtractedDir(dir)
	require.NoError(t, err)
	assert.Empty(t, keys)
}

func TestGetKeysFromExtractedDir_Missing(t *testing.T) {
	keys, err := GetKeysFromExtractedDir(filepath.Join(t.TempDir(), "nonexistent"))
	require.NoError(t, err)
	assert.Empty(t, keys)
}

func TestGetKeysFromExtractedDir_YAMLStems(t *testing.T) {
	dir := t.TempDir()
	require.NoError(t, os.WriteFile(filepath.Join(dir, "ConfigMap_ns_foo.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(dir, "Secret_ns_bar.yaml"), []byte("{}"), 0o600))
	// Non-YAML files must be ignored.
	require.NoError(t, os.WriteFile(filepath.Join(dir, "readme.txt"), []byte("x"), 0o600))

	keys, err := GetKeysFromExtractedDir(dir)
	require.NoError(t, err)
	assert.ElementsMatch(t, []string{"ConfigMap_ns_foo", "Secret_ns_bar"}, keys)
}

func TestCorrelate(t *testing.T) {
	oldDir := t.TempDir()
	newDir := t.TempDir()
	require.NoError(t, os.WriteFile(filepath.Join(oldDir, "OnlyOld.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newDir, "OnlyNew.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(oldDir, "Both.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newDir, "Both.yaml"), []byte("{}"), 0o600))

	onlyOld, onlyNew, inBoth, err := Correlate(oldDir, newDir)
	require.NoError(t, err)
	assert.Equal(t, []string{"OnlyOld"}, onlyOld)
	assert.Equal(t, []string{"OnlyNew"}, onlyNew)
	assert.Equal(t, []string{"Both"}, inBoth)
}

func TestCorrelate_InBothSorted(t *testing.T) {
	oldDir := t.TempDir()
	newDir := t.TempDir()
	for _, name := range []string{"Z", "A", "M"} {
		require.NoError(t, os.WriteFile(filepath.Join(oldDir, name+".yaml"), []byte("{}"), 0o600))
		require.NoError(t, os.WriteFile(filepath.Join(newDir, name+".yaml"), []byte("{}"), 0o600))
	}
	_, _, inBoth, err := Correlate(oldDir, newDir)
	require.NoError(t, err)
	assert.Equal(t, []string{"A", "M", "Z"}, inBoth)
}

func TestNormalizeYAML_SortedKeys(t *testing.T) {
	obj := map[string]any{"b": 2, "a": 1}
	out, err := NormalizeYAML(obj)
	require.NoError(t, err)
	assert.Contains(t, out, "a:")
	assert.Contains(t, out, "b:")
	// "a" must be rendered before "b".
	aIdx := strings.Index(out, "a:")
	bIdx := strings.Index(out, "b:")
	assert.Less(t, aIdx, bIdx)
}

func TestComputeDiff_Identical(t *testing.T) {
	dir := t.TempDir()
	content := "key: value\n"
	require.NoError(t, os.WriteFile(filepath.Join(dir, "a.yaml"), []byte(content), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(dir, "b.yaml"), []byte(content), 0o600))

	diff, err := ComputeDiff(filepath.Join(dir, "a.yaml"), filepath.Join(dir, "b.yaml"))
	require.NoError(t, err)
	assert.Empty(t, diff)
}

func TestComputeDiff_Different(t *testing.T) {
	dir := t.TempDir()
	require.NoError(t, os.WriteFile(filepath.Join(dir, "a.yaml"), []byte("key: value1\n"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(dir, "b.yaml"), []byte("key: value2\n"), 0o600))

	diff, err := ComputeDiff(filepath.Join(dir, "a.yaml"), filepath.Join(dir, "b.yaml"))
	require.NoError(t, err)
	assert.True(t, strings.Contains(diff, "value1") || strings.Contains(diff, "value2"))
	assert.True(t, strings.Contains(diff, "---") || strings.Contains(diff, "+++"))
}

func TestBuildSummary(t *testing.T) {
	s := BuildSummary([]string{"a"}, []string{"b"}, []string{"c", "d"}, 1)
	assert.Contains(t, s, "Only in old: 1")
	assert.Contains(t, s, "Only in new: 1")
	assert.Contains(t, s, "In both: 2")
	assert.Contains(t, s, "Differ: 1")
}

func TestRun_ReportAndJSON(t *testing.T) {
	root := t.TempDir()
	oldDir := filepath.Join(root, "old")
	newDir := filepath.Join(root, "new")
	require.NoError(t, os.MkdirAll(oldDir, 0o750))
	require.NoError(t, os.MkdirAll(newDir, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(oldDir, "Same.yaml"), []byte("x: 1\n"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newDir, "Same.yaml"), []byte("x: 1\n"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(oldDir, "OnlyO.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newDir, "OnlyN.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(oldDir, "Diff.yaml"), []byte("a: 1\n"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newDir, "Diff.yaml"), []byte("a: 2\n"), 0o600))

	reportPath := filepath.Join(root, "report.txt")
	jsonPath := filepath.Join(root, "comparison.json")
	summary, err := Run(oldDir, newDir, reportPath, jsonPath)
	require.NoError(t, err)
	assert.Contains(t, summary, "Only in old: 1")
	assert.Contains(t, summary, "Only in new: 1")
	assert.Contains(t, summary, "In both: 2")

	reportData, err := os.ReadFile(reportPath)
	require.NoError(t, err)
	assert.Contains(t, string(reportData), "Comparison summary")

	jsonData, err := os.ReadFile(jsonPath)
	require.NoError(t, err)
	var data ComparisonJSON
	require.NoError(t, json.Unmarshal(jsonData, &data))
	assert.Equal(t, []string{"OnlyO"}, data.OnlyOld)
	assert.Equal(t, []string{"OnlyN"}, data.OnlyNew)
	assert.ElementsMatch(t, []string{"Diff", "Same"}, data.InBoth)
	assert.Contains(t, data.Diffs, "Diff")
	assert.NotEmpty(t, data.Diffs["Diff"].OldContent)
	assert.NotEmpty(t, data.Diffs["Diff"].NewContent)
	assert.NotEmpty(t, data.Diffs["Diff"].DiffText)
	assert.Equal(t, oldDir, data.OldExtracted)
	assert.Equal(t, newDir, data.NewExtracted)
}

func TestRun_CreatesParentDirs(t *testing.T) {
	root := t.TempDir()
	oldDir := filepath.Join(root, "old")
	newDir := filepath.Join(root, "new")
	require.NoError(t, os.MkdirAll(oldDir, 0o750))
	require.NoError(t, os.MkdirAll(newDir, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(oldDir, "K.yaml"), []byte("{}"), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newDir, "K.yaml"), []byte("{}"), 0o600))

	// Output paths under a directory that does not exist yet.
	reportPath := filepath.Join(root, "out", "sub", "report.txt")
	jsonPath := filepath.Join(root, "out", "sub", "comparison.json")
	_, err := Run(oldDir, newDir, reportPath, jsonPath)
	require.NoError(t, err)
	assert.FileExists(t, reportPath)
	assert.FileExists(t, jsonPath)
}

func TestRun_NoOverlap(t *testing.T) {
	root := t.TempDir()
	oldDir := filepath.Join(root, "old")
	newDir := filepath.Join(root, "new")
	require.NoError(t, os.MkdirAll(oldDir, 0o750))
	require.NoError(t, os.MkdirAll(newDir, 0o750))

	reportPath := filepath.Join(root, "report.txt")
	jsonPath := filepath.Join(root, "comparison.json")
	summary, err := Run(oldDir, newDir, reportPath, jsonPath)
	require.NoError(t, err)
	assert.Contains(t, summary, "Only in old: 0")
	assert.Contains(t, summary, "In both: 0")

	jsonData, err := os.ReadFile(jsonPath)
	require.NoError(t, err)
	var data ComparisonJSON
	require.NoError(t, json.Unmarshal(jsonData, &data))
	assert.Empty(t, data.OnlyOld)
	assert.Empty(t, data.InBoth)
	assert.Empty(t, data.Diffs)
}

func TestCRKeyFromObject(t *testing.T) {
	cases := []struct {
		name     string
		obj      map[string]any
		expected string
	}{
		{
			name: "namespaced resource",
			obj: map[string]any{
				"kind":     "ConfigMap",
				"metadata": map[string]any{"name": "foo", "namespace": "openshift-monitoring"},
			},
			expected: "ConfigMap_openshift-monitoring_foo",
		},
		{
			name: "cluster-scoped resource",
			obj: map[string]any{
				"kind":     "ImageContentSourcePolicy",
				"metadata": map[string]any{"name": "my-icsp"},
			},
			expected: "ImageContentSourcePolicy_my-icsp",
		},
		{
			name: "empty namespace is cluster-scoped",
			obj: map[string]any{
				"kind":     "Resource",
				"metadata": map[string]any{"name": "x", "namespace": ""},
			},
			expected: "Resource_x",
		},
		{
			name: "sanitizes slashes",
			obj: map[string]any{
				"kind":     "Something",
				"metadata": map[string]any{"name": "a/b/c", "namespace": "ns"},
			},
			expected: "Something_ns_a-b-c",
		},
		{
			name:     "missing kind uses Unknown",
			obj:      map[string]any{"metadata": map[string]any{"name": "n", "namespace": "ns"}},
			expected: "Unknown_ns_n",
		},
		{
			name:     "missing name uses unnamed",
			obj:      map[string]any{"kind": "ConfigMap", "metadata": map[string]any{"namespace": "ns"}},
			expected: "ConfigMap_ns_unnamed",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.expected, CRKeyFromObject(tc.obj))
		})
	}
}

func TestExtractCRsFromPolicyDoc_MinimalPolicy(t *testing.T) {
	var doc map[string]any
	require.NoError(t, sigsyaml.Unmarshal([]byte(minimalPolicyYAML), &doc))

	result := ExtractCRsFromPolicyDoc(doc)
	require.Len(t, result, 2)
	keys := make([]string, len(result))
	for i, r := range result {
		keys[i] = r.Key
	}
	assert.Contains(t, keys, "ConfigMap_openshift-monitoring_my-config")
	assert.Contains(t, keys, "ImageContentSourcePolicy_my-icsp")
}

func TestExtractCRsFromPolicyDoc_NonPolicy(t *testing.T) {
	doc := map[string]any{"kind": "ConfigMap", "metadata": map[string]any{}}
	assert.Empty(t, ExtractCRsFromPolicyDoc(doc))
}

func TestExtractCRsFromPolicyDoc_NoPolicyTemplates(t *testing.T) {
	doc := map[string]any{"kind": "Policy", "spec": map[string]any{}}
	assert.Empty(t, ExtractCRsFromPolicyDoc(doc))
}

func TestExtractCRs_OneFilePerCR(t *testing.T) {
	root := t.TempDir()
	generated := filepath.Join(root, "generated")
	extracted := filepath.Join(root, "extracted")
	require.NoError(t, os.MkdirAll(generated, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(generated, "policies.yaml"), []byte(minimalPolicyYAML), 0o600))

	require.NoError(t, ExtractCRs(generated, extracted))

	entries, err := os.ReadDir(extracted)
	require.NoError(t, err)
	names := make([]string, len(entries))
	for i, e := range entries {
		names[i] = e.Name()
	}
	assert.ElementsMatch(t, []string{
		"ConfigMap_openshift-monitoring_my-config.yaml",
		"ImageContentSourcePolicy_my-icsp.yaml",
	}, names)
}

func TestExtractCRs_ValidContent(t *testing.T) {
	root := t.TempDir()
	generated := filepath.Join(root, "generated")
	extracted := filepath.Join(root, "extracted")
	require.NoError(t, os.MkdirAll(generated, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(generated, "p.yaml"), []byte(minimalPolicyYAML), 0o600))

	require.NoError(t, ExtractCRs(generated, extracted))

	data, err := os.ReadFile(filepath.Join(extracted, "ConfigMap_openshift-monitoring_my-config.yaml"))
	require.NoError(t, err)
	var obj map[string]any
	require.NoError(t, sigsyaml.Unmarshal(data, &obj))
	assert.Equal(t, "ConfigMap", obj["kind"])
	meta, ok := obj["metadata"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "my-config", meta["name"])
	assert.Equal(t, "openshift-monitoring", meta["namespace"])
	dataMap, ok := obj["data"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "value", dataMap["key"])
}

func TestExtractCRs_DuplicateKeyLastWins(t *testing.T) {
	root := t.TempDir()
	generated := filepath.Join(root, "generated")
	extracted := filepath.Join(root, "extracted")
	require.NoError(t, os.MkdirAll(generated, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(generated, "dup.yaml"), []byte(policyWithDuplicateKey), 0o600))

	require.NoError(t, ExtractCRs(generated, extracted))

	data, err := os.ReadFile(filepath.Join(extracted, "ConfigMap_ns1_same-name.yaml"))
	require.NoError(t, err)
	var obj map[string]any
	require.NoError(t, sigsyaml.Unmarshal(data, &obj))
	dataMap, ok := obj["data"].(map[string]any)
	require.True(t, ok)
	assert.Equal(t, "b", dataMap["second"])
	assert.NotContains(t, dataMap, "first")
}

func TestExtractCRs_CreatesExtractedDir(t *testing.T) {
	root := t.TempDir()
	generated := filepath.Join(root, "generated")
	extracted := filepath.Join(root, "extracted")
	require.NoError(t, os.MkdirAll(generated, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(generated, "p.yaml"), []byte(minimalPolicyYAML), 0o600))

	_, err := os.Stat(extracted)
	require.True(t, os.IsNotExist(err))

	require.NoError(t, ExtractCRs(generated, extracted))

	info, err := os.Stat(extracted)
	require.NoError(t, err)
	assert.True(t, info.IsDir())
}

func TestRunCompare_ExtractsAndWritesArtifacts(t *testing.T) {
	root := t.TempDir()
	oldGen := filepath.Join(root, "old-gen")
	newGen := filepath.Join(root, "new-gen")
	require.NoError(t, os.MkdirAll(oldGen, 0o750))
	require.NoError(t, os.MkdirAll(newGen, 0o750))
	require.NoError(t, os.WriteFile(filepath.Join(oldGen, "p.yaml"), []byte(minimalPolicyYAML), 0o600))
	require.NoError(t, os.WriteFile(filepath.Join(newGen, "p.yaml"), []byte(minimalPolicyYAML), 0o600))

	sessionDir := filepath.Join(root, "session")
	require.NoError(t, os.MkdirAll(sessionDir, 0o750))

	summary, err := RunCompare(oldGen, newGen, sessionDir)
	require.NoError(t, err)
	assert.Contains(t, summary, "Comparison summary")
	assert.FileExists(t, filepath.Join(sessionDir, "diff-report.txt"))
	assert.FileExists(t, filepath.Join(sessionDir, "comparison.json"))
}

func TestConstants(t *testing.T) {
	assert.Equal(t, "argocd/example/acmpolicygenerator", CRSPath)
	assert.Equal(t, "source-crs", SourceCRSPath)
	assert.Contains(t, ListOfCRsForSNO, "acm-common-ranGen.yaml")
	assert.Len(t, ListOfCRsForSNO, 4)
}
diff --git a/pkg/rdsdiff/constants.go b/pkg/rdsdiff/constants.go new file mode 100644 index 00000000..869e51fd --- /dev/null +++ b/pkg/rdsdiff/constants.go @@ -0,0 +1,19 @@
// SPDX-License-Identifier:Apache-2.0

package rdsdiff

// Paths under the telco-reference configuration root (e.g. telco-ran/configuration).
const (
	// CRSPath is the directory containing PolicyGenerator policy YAMLs and source-crs copy.
+ // Matches telco-reference: configuration/argocd/example/acmpolicygenerator + CRSPath = "argocd/example/acmpolicygenerator" + SourceCRSPath = "source-crs" +) + +// ListOfCRsForSNO is the list of SNO policy filenames used by PolicyGenerator. +var ListOfCRsForSNO = []string{ + "acm-common-ranGen.yaml", + "acm-example-sno-site.yaml", + "acm-group-du-sno-ranGen.yaml", + "acm-group-du-sno-validator-ranGen.yaml", +} diff --git a/pkg/rdsdiff/download.go b/pkg/rdsdiff/download.go new file mode 100644 index 00000000..d2135055 --- /dev/null +++ b/pkg/rdsdiff/download.go @@ -0,0 +1,398 @@ +// SPDX-License-Identifier:Apache-2.0 + +package rdsdiff + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" +) + +// GitHubTreeURL holds parsed parts of a GitHub tree URL. +// Format: https://github.com///tree// +type GitHubTreeURL struct { + Owner string + Repo string + Branch string + Path string // e.g. "telco-ran/configuration" +} + +// DefaultDownloadTimeout is the HTTP client timeout for fetching archives. +const DefaultDownloadTimeout = 5 * time.Minute + +// MaxExtractedFileSize is the maximum size in bytes for a single extracted file (500 MB). +const MaxExtractedFileSize = 500 * 1024 * 1024 + +// safeJoinPath joins destDir and name and returns an error if the result escapes destDir (zip-slip prevention). +func safeJoinPath(destDir, name string) (string, error) { + clean := filepath.Clean(name) + if clean == "" || clean == "." 
{ + return destDir, nil + } + if filepath.IsAbs(clean) || strings.Contains(clean, "..") { + return "", fmt.Errorf("unsafe path in archive: %s", name) + } + joined := filepath.Join(destDir, filepath.FromSlash(clean)) + rel, err := filepath.Rel(destDir, joined) + if err != nil || strings.HasPrefix(rel, "..") { + return "", fmt.Errorf("path escapes destination: %s", name) + } + return joined, nil +} + +// ParseGitHubTreeURL parses a full GitHub tree URL and returns owner, repo, branch, path. +// Path may be empty if URL points to repo root. Returns error for non-GitHub or malformed URLs. +// +// Format: https://github.com///tree// +func ParseGitHubTreeURL(raw string) (*GitHubTreeURL, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, errors.New("URL is required") + } + u, err := url.Parse(raw) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + if u.Scheme != "https" || u.Host != "github.com" { + return nil, errors.New("only https://github.com/.../tree// URLs are supported") + } + path := strings.Trim(u.Path, "/") + parts := strings.SplitN(path, "/", 6) + // need at least: owner, repo, "tree", branch (4 segments) + if len(parts) < 4 || parts[2] != "tree" { + return nil, errors.New("URL must match https://github.com///tree//") + } + owner, repo, branch := parts[0], parts[1], parts[3] + subPath := "" + if len(parts) > 4 { + subPath = strings.Join(parts[4:], "/") + } + return &GitHubTreeURL{ + Owner: owner, + Repo: repo, + Branch: branch, + Path: strings.TrimSuffix(subPath, "/"), + }, nil +} + +// ArchiveURL returns the GitHub archive URL for this tree (branch only; path is applied after extract). +func (g *GitHubTreeURL) ArchiveURL() string { + return fmt.Sprintf("https://github.com/%s/%s/archive/refs/heads/%s.zip", g.Owner, g.Repo, g.Branch) +} + +// TopLevelDir is the single top-level directory inside the zip (GitHub uses -). 
+func (g *GitHubTreeURL) TopLevelDir() string { + return g.Repo + "-" + g.Branch +} + +// EffectiveRoot returns the path inside the extracted archive that is the configuration root. +// Empty Path means repo root, so EffectiveRoot == TopLevelDir. +func (g *GitHubTreeURL) EffectiveRoot(extractDir string) string { + if g.Path == "" { + return filepath.Join(extractDir, g.TopLevelDir()) + } + return filepath.Join(extractDir, g.TopLevelDir(), filepath.FromSlash(g.Path)) +} + +// DownloadGitHubTree downloads the GitHub archive for the given tree URL, extracts it to destDir, +// and returns the effective root path (extractDir/topLevel/path). destDir must exist. +func DownloadGitHubTree(treeURL, destDir string, client *http.Client) (string, error) { + parsed, err := ParseGitHubTreeURL(treeURL) + if err != nil { + return "", fmt.Errorf("parse github tree URL: %w", err) + } + if client == nil { + client = &http.Client{Timeout: DefaultDownloadTimeout} + } + archiveURL := parsed.ArchiveURL() + resp, err := client.Get(archiveURL) // nolint:noctx + if err != nil { + return "", fmt.Errorf("download %s: %w", archiveURL, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download %s: status %s", archiveURL, resp.Status) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read archive: %w", err) + } + if err := extractZipBytes(data, destDir); err != nil { + return "", fmt.Errorf("extract zip: %w", err) + } + return parsed.EffectiveRoot(destDir), nil +} + +// downloadDirectURL fetches a non-GitHub URL and extracts zip or tar.gz into destDir. 
+func downloadDirectURL(rawURL, destDir string, client *http.Client) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", fmt.Errorf("invalid URL: %w", err) + } + if u.Scheme != "https" && u.Scheme != "http" { + return "", fmt.Errorf("unsupported scheme %q (use https or http)", u.Scheme) + } + resp, err := client.Get(rawURL) // nolint:noctx + if err != nil { + return "", fmt.Errorf("download failed: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download failed: %s", resp.Status) + } + data, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + if len(data) == 0 { + return "", errors.New("download returned empty body") + } + ct := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Type"))) + pathLower := strings.ToLower(u.Path) + isZip := strings.Contains(ct, "zip") || strings.HasSuffix(pathLower, ".zip") + isGzip := strings.Contains(ct, "gzip") || strings.HasSuffix(pathLower, ".gz") || strings.HasSuffix(pathLower, ".tgz") + + if isZip { + effectiveRoot, err := extractZip(data, destDir) + if err != nil { + return "", fmt.Errorf("extract zip: %w", err) + } + return effectiveRoot, nil + } + if isGzip || (len(data) >= 2 && data[0] == 0x1f && data[1] == 0x8b) { + effectiveRoot, err := extractTarGz(data, destDir) + if err != nil { + return "", fmt.Errorf("extract tar.gz: %w", err) + } + return effectiveRoot, nil + } + if effectiveRoot, err := extractZip(data, destDir); err == nil { + return effectiveRoot, nil + } + if effectiveRoot, err := extractTarGz(data, destDir); err == nil { + return effectiveRoot, nil + } + return "", fmt.Errorf("unsupported archive format (expected zip or tar.gz); Content-Type: %q", ct) +} + +// DownloadURL downloads from any supported URL and extracts the archive to destDir. +// Returns the effective root path where extracted content lives (single top-level dir if present, else destDir). 
+// Supports: +// - GitHub tree URLs: https://github.com///tree// (fetches GitHub archive zip). +// - Direct URLs to a .zip or .tar.gz/.tgz archive; the response body is extracted as-is. +// +// destDir must exist. If the download or extraction fails, returns an error. +func DownloadURL(rawURL, destDir string, client *http.Client) (string, error) { + rawURL = strings.TrimSpace(rawURL) + if rawURL == "" { + return "", errors.New("URL is required") + } + if client == nil { + client = &http.Client{Timeout: DefaultDownloadTimeout} + } + if _, parseErr := ParseGitHubTreeURL(rawURL); parseErr == nil { + return DownloadGitHubTree(rawURL, destDir, client) + } + return downloadDirectURL(rawURL, destDir, client) +} + +// extractZipBytes extracts zip data into destDir without tracking top-level dirs. +func extractZipBytes(data []byte, destDir string) error { + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return fmt.Errorf("open zip: %w", err) + } + destDirClean := filepath.Clean(destDir) + for _, f := range zr.File { + name := f.Name + if f.FileInfo().IsDir() { + name = strings.TrimSuffix(name, "/") + if name == "" { + continue + } + dstPath, err := safeJoinPath(destDirClean, name) + if err != nil { + return err + } + if err := os.MkdirAll(dstPath, 0o750); err != nil { + return fmt.Errorf("mkdir %s: %w", dstPath, err) + } + continue + } + if err := extractZipFile(f, destDirClean); err != nil { + return err + } + } + return nil +} + +func extractZipFile(f *zip.File, destDirClean string) error { + dstPath, err := safeJoinPath(destDirClean, f.Name) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(dstPath), 0o750); err != nil { + return fmt.Errorf("mkdir parent of %s: %w", dstPath, err) + } + rc, err := f.Open() + if err != nil { + return fmt.Errorf("open zip entry %s: %w", f.Name, err) + } + defer rc.Close() + dst, err := os.Create(dstPath) // #nosec G304 -- dstPath from safeJoinPath + if err != nil { + return 
fmt.Errorf("create %s: %w", dstPath, err) + } + defer dst.Close() + if _, err = io.Copy(dst, io.LimitReader(rc, MaxExtractedFileSize)); err != nil { // #nosec G110 -- size bounded by MaxExtractedFileSize + return fmt.Errorf("copy %s: %w", f.Name, err) + } + return nil +} + +// extractZip extracts zip data into destDir and returns the effective root +// (single top-level directory if present, else destDir). +func extractZip(data []byte, destDir string) (string, error) { + zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return "", fmt.Errorf("open zip: %w", err) + } + destDirClean := filepath.Clean(destDir) + topLevel := detectZipTopLevel(zr) + for _, f := range zr.File { + if f.FileInfo().IsDir() { + name := strings.TrimSuffix(f.Name, "/") + if name == "" { + continue + } + dstPath, err := safeJoinPath(destDirClean, name) + if err != nil { + return "", err + } + if err := os.MkdirAll(dstPath, 0o750); err != nil { + return "", fmt.Errorf("mkdir %s: %w", dstPath, err) + } + continue + } + if err := extractZipFile(f, destDirClean); err != nil { + return "", err + } + } + if topLevel != "" { + return filepath.Join(destDir, topLevel), nil + } + return destDir, nil +} + +// detectZipTopLevel returns the single top-level directory name if all entries share one, +// or empty string if there are multiple top-level entries. +func detectZipTopLevel(zr *zip.Reader) string { + var topLevel string + for _, f := range zr.File { + name := f.Name + if f.FileInfo().IsDir() { + name = strings.TrimSuffix(name, "/") + if name == "" { + continue + } + } + first := strings.SplitN(name, "/", 2)[0] + if topLevel == "" { + topLevel = first + } else if topLevel != first { + return "" + } + } + return topLevel +} + +// extractTarGz extracts gzip-compressed tar data into destDir and returns the effective root. 
+func extractTarGz(data []byte, destDir string) (string, error) { + gr, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return "", fmt.Errorf("open gzip: %w", err) + } + defer gr.Close() + destDirClean := filepath.Clean(destDir) + tr := tar.NewReader(gr) + var topLevel string + for { + header, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return "", fmt.Errorf("read tar: %w", err) + } + name := strings.TrimPrefix(filepath.Clean(header.Name), ".") + name = strings.TrimPrefix(name, "/") + if name == "" { + continue + } + dstPath, err := safeJoinPath(destDirClean, name) + if err != nil { + return "", err + } + first := strings.SplitN(name, "/", 2)[0] + if topLevel == "" { + topLevel = first + } else if topLevel != first { + topLevel = "" + } + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(dstPath, 0o750); err != nil { + return "", fmt.Errorf("mkdir %s: %w", dstPath, err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(dstPath), 0o750); err != nil { + return "", fmt.Errorf("mkdir parent of %s: %w", dstPath, err) + } + outFile, err := os.Create(dstPath) // #nosec G304 -- dstPath from safeJoinPath + if err != nil { + return "", fmt.Errorf("create %s: %w", dstPath, err) + } + if _, err = io.Copy(outFile, io.LimitReader(tr, MaxExtractedFileSize)); err != nil { // #nosec G110 -- size bounded by MaxExtractedFileSize + outFile.Close() + return "", fmt.Errorf("copy %s: %w", name, err) + } + outFile.Close() + default: + // skip symlinks, etc. + } + } + if topLevel != "" { + return filepath.Join(destDir, topLevel), nil + } + return destDir, nil +} + +// ValidateConfigurationRoot checks that source-crs and CRSPath exist under root. 
func ValidateConfigurationRoot(root string) error {
	// Both directories are required by downstream steps: source-crs is copied
	// into CRSPath, and CRSPath holds the PolicyGenerator inputs.
	sourceCRS := filepath.Join(root, SourceCRSPath)
	crsPath := filepath.Join(root, CRSPath)
	if _, err := os.Stat(sourceCRS); err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("configuration root missing %s (path: %s)", SourceCRSPath, sourceCRS)
		}
		return fmt.Errorf("stat %s: %w", sourceCRS, err)
	}
	if _, err := os.Stat(crsPath); err != nil {
		if os.IsNotExist(err) {
			return fmt.Errorf("configuration root missing %s (path: %s)", CRSPath, crsPath)
		}
		return fmt.Errorf("stat %s: %w", crsPath, err)
	}
	return nil
}
diff --git a/pkg/rdsdiff/download_test.go b/pkg/rdsdiff/download_test.go
new file mode 100644
index 00000000..19cd867c
--- /dev/null
+++ b/pkg/rdsdiff/download_test.go
// SPDX-License-Identifier:Apache-2.0

package rdsdiff

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestParseGitHubTreeURL_FullURL covers a tree URL with a multi-segment
// sub-path, plus the derived archive URL and top-level directory name.
func TestParseGitHubTreeURL_FullURL(t *testing.T) {
	u, err := ParseGitHubTreeURL("https://github.com/openshift-kni/telco-reference/tree/konflux-telco-core-rds-4-20/telco-ran/configuration")
	require.NoError(t, err)
	assert.Equal(t, "openshift-kni", u.Owner)
	assert.Equal(t, "telco-reference", u.Repo)
	assert.Equal(t, "konflux-telco-core-rds-4-20", u.Branch)
	assert.Equal(t, "telco-ran/configuration", u.Path)
	assert.Equal(t, "https://github.com/openshift-kni/telco-reference/archive/refs/heads/konflux-telco-core-rds-4-20.zip", u.ArchiveURL())
	assert.Equal(t, "telco-reference-konflux-telco-core-rds-4-20", u.TopLevelDir())
}

// TestParseGitHubTreeURL_BranchOnly verifies that a URL pointing at a branch
// root (no sub-path) yields an empty Path.
func TestParseGitHubTreeURL_BranchOnly(t *testing.T) {
	u, err := ParseGitHubTreeURL("https://github.com/org/repo/tree/main")
	require.NoError(t, err)
	assert.Equal(t, "org", u.Owner)
	assert.Equal(t, "repo", u.Repo)
	assert.Equal(t, "main", u.Branch)
	assert.Empty(t, u.Path)
}

// TestParseGitHubTreeURL_Errors checks the rejection paths: empty input,
// non-GitHub hosts, and URLs without the /tree/<branch> segment.
func TestParseGitHubTreeURL_Errors(t *testing.T) {
	cases := []struct {
		name string
		url  string
	}{
		{"empty URL", ""},
		{"non-GitHub URL", "https://gitlab.com/org/repo/-/tree/branch/path"},
		{"URL without tree", "https://github.com/org/repo"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := ParseGitHubTreeURL(tc.url)
			assert.Error(t, err)
		})
	}
}

// Missing source-crs: error message must name the missing directory.
func TestValidateConfigurationRoot_MissingSourceCRS(t *testing.T) {
	root := t.TempDir()
	require.NoError(t, os.MkdirAll(filepath.Join(root, CRSPath), 0o750))
	err := ValidateConfigurationRoot(root)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "source-crs")
}

// Missing CRSPath: error message must name the missing directory.
func TestValidateConfigurationRoot_MissingCRSPath(t *testing.T) {
	root := t.TempDir()
	require.NoError(t, os.MkdirAll(filepath.Join(root, SourceCRSPath), 0o750))
	err := ValidateConfigurationRoot(root)
	require.Error(t, err)
	assert.Contains(t, err.Error(), CRSPath)
}

// Both required directories present: validation succeeds.
func TestValidateConfigurationRoot_Success(t *testing.T) {
	root := t.TempDir()
	require.NoError(t, os.MkdirAll(filepath.Join(root, SourceCRSPath), 0o750))
	require.NoError(t, os.MkdirAll(filepath.Join(root, CRSPath), 0o750))
	require.NoError(t, ValidateConfigurationRoot(root))
}

func TestDownloadURL_EmptyURL(t *testing.T) {
	dir := t.TempDir()
	_, err := DownloadURL("", dir, nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "URL is required")
}

// Non-http(s) schemes are rejected before any network activity.
func TestDownloadURL_UnsupportedScheme(t *testing.T) {
	dir := t.TempDir()
	_, err := DownloadURL("ftp://example.com/archive.zip", dir, nil)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "unsupported scheme")
}

// The session directory must exist on disk and embed the request ID in its name.
func TestSessionDir_CreatesUnderWorkDir(t *testing.T) {
	workDir := t.TempDir()
	sessionDir, err := SessionDir(workDir, "req-1")
	require.NoError(t, err)
	assert.True(t, len(sessionDir) > len(workDir))
	assert.Contains(t, sessionDir, "rds-diff-req-1")

	info, err := os.Stat(sessionDir)
	require.NoError(t, err)
	assert.True(t, info.IsDir())
}

// An empty workDir falls back to os.TempDir().
func TestSessionDir_UsesTempWhenEmpty(t *testing.T) {
	sessionDir, err := SessionDir("", "req-2")
	require.NoError(t, err)
	defer os.RemoveAll(sessionDir) // created under os.TempDir(); clean up explicitly
	assert.NotEmpty(t, sessionDir)
	assert.Contains(t, sessionDir, "rds-diff-req-2")
}
diff --git a/pkg/rdsdiff/fileio.go b/pkg/rdsdiff/fileio.go
new file mode 100644
index 00000000..370df46e
--- /dev/null
+++ b/pkg/rdsdiff/fileio.go
// SPDX-License-Identifier:Apache-2.0

package rdsdiff

import (
	"fmt"
	"os"
	"path/filepath"
)

// readCleanFile reads the named file, normalizing the path with filepath.Clean
// first. NOTE(review): Clean only collapses separators and dot segments; it does
// not by itself prevent traversal — callers are expected to pass trusted or
// pre-validated paths.
func readCleanFile(name string) ([]byte, error) {
	data, err := os.ReadFile(filepath.Clean(name)) // #nosec G304 -- path cleaned
	if err != nil {
		return nil, fmt.Errorf("read %s: %w", name, err)
	}
	return data, nil
}

// writeCleanFile writes data to the named file with mode 0600, normalizing the
// path with filepath.Clean first. Same traversal caveat as readCleanFile.
func writeCleanFile(name string, data []byte) error {
	if err := os.WriteFile(filepath.Clean(name), data, 0o600); err != nil { // #nosec G304 -- path cleaned
		return fmt.Errorf("write %s: %w", name, err)
	}
	return nil
}

// readCleanDir lists the entries of the named directory, normalizing the path
// with filepath.Clean first. Same traversal caveat as readCleanFile.
func readCleanDir(name string) ([]os.DirEntry, error) {
	entries, err := os.ReadDir(filepath.Clean(name))
	if err != nil {
		return nil, fmt.Errorf("readdir %s: %w", name, err)
	}
	return entries, nil
}
diff --git a/pkg/rdsdiff/policygen.go b/pkg/rdsdiff/policygen.go
new file mode 100644
index 00000000..19d8b636
--- /dev/null
+++ b/pkg/rdsdiff/policygen.go
// SPDX-License-Identifier:Apache-2.0

package rdsdiff

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"
)

// RunPolicyGen copies source-crs into CRSPath under configRoot, runs the PolicyGenerator binary
// for each SNO policy file, and writes generated YAMLs to generatedDir.
// configRoot is the effective configuration root (e.g. .../telco-ran/configuration).
+// policyGeneratorPath is the path to the PolicyGenerator binary. +func RunPolicyGen(configRoot, policyGeneratorPath, generatedDir string) error { + configRoot = filepath.Clean(configRoot) + generatedDir = filepath.Clean(generatedDir) + sourceCRS := filepath.Join(configRoot, SourceCRSPath) + crsPathDir := filepath.Join(configRoot, CRSPath) + if err := os.MkdirAll(generatedDir, 0o750); err != nil { + return fmt.Errorf("mkdir generated dir: %w", err) + } + // Copy source-crs into CRSPath so PolicyGenerator finds it + destSourceCRS := filepath.Join(crsPathDir, "source-crs") + if err := os.RemoveAll(destSourceCRS); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("remove dest source-crs: %w", err) + } + if err := copyDir(sourceCRS, destSourceCRS); err != nil { + return fmt.Errorf("copy source-crs to %s: %w", destSourceCRS, err) + } + for _, policyFile := range ListOfCRsForSNO { + policyPath := filepath.Join(crsPathDir, policyFile) + if _, err := os.Stat(policyPath); err != nil { + if os.IsNotExist(err) { + continue + } + return fmt.Errorf("stat %s: %w", policyPath, err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + cmd := exec.CommandContext(ctx, policyGeneratorPath, policyPath) // #nosec G204 -- paths from config root + cmd.Dir = crsPathDir + out, err := cmd.CombinedOutput() + cancel() + if err != nil { + return fmt.Errorf("policyGenerator %s: %w\n%s", policyFile, err, string(out)) + } + base := strings.TrimSuffix(policyFile, filepath.Ext(policyFile)) + outPath := filepath.Join(generatedDir, base+"-generated.yaml") + if err := writeCleanFile(outPath, out); err != nil { + return fmt.Errorf("write PolicyGenerator output %s: %w", outPath, err) + } + } + return nil +} + +func copyDir(src, dst string) error { + entries, err := readCleanDir(src) + if err != nil { + return fmt.Errorf("read dir %s: %w", src, err) + } + if err := os.MkdirAll(dst, 0o750); err != nil { + return fmt.Errorf("mkdir %s: %w", dst, err) + } + for _, e := range 
entries { + srcPath, err := safeJoinPath(src, e.Name()) + if err != nil { + return fmt.Errorf("unsafe entry in %s: %w", src, err) + } + dstPath, err := safeJoinPath(dst, e.Name()) + if err != nil { + return fmt.Errorf("unsafe entry in %s: %w", dst, err) + } + if e.IsDir() { + if err := copyDir(srcPath, dstPath); err != nil { + return err + } + continue + } + data, err := readCleanFile(srcPath) + if err != nil { + return fmt.Errorf("read %s: %w", srcPath, err) + } + if err := writeCleanFile(dstPath, data); err != nil { + return fmt.Errorf("write %s: %w", dstPath, err) + } + } + return nil +} diff --git a/pkg/rdsdiff/storage.go b/pkg/rdsdiff/storage.go new file mode 100644 index 00000000..b2ff1be9 --- /dev/null +++ b/pkg/rdsdiff/storage.go @@ -0,0 +1,78 @@ +// SPDX-License-Identifier:Apache-2.0 + +package rdsdiff + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +// SessionDir creates a session directory under workDir and returns its path. +// workDir defaults to os.TempDir() if empty. Session name is rds-diff--. +func SessionDir(workDir, requestID string) (string, error) { + if workDir == "" { + workDir = os.TempDir() + } + workDir = filepath.Clean(workDir) + if err := os.MkdirAll(workDir, 0o750); err != nil { + return "", fmt.Errorf("mkdir work dir: %w", err) + } + name := fmt.Sprintf("rds-diff-%s-%d", requestID, time.Now().Unix()) + sessionPath := filepath.Join(workDir, name) + if err := os.MkdirAll(sessionPath, 0o750); err != nil { + return "", fmt.Errorf("mkdir session: %w", err) + } + return sessionPath, nil +} + +// SessionID returns the session directory name (last path component) for use as a stable artifact ID. +func SessionID(sessionPath string) string { + return filepath.Base(sessionPath) +} + +// ErrInvalidSessionID is returned when the session ID is invalid (e.g. path traversal). 
+var ErrInvalidSessionID = errors.New("invalid session id") + +// ResolveSessionPath resolves workDir + sessionID to an absolute session path and validates that it +// is under workDir and exists as a directory. sessionID must be a single path segment (no slashes or ".."). +func ResolveSessionPath(workDir, sessionID string) (string, error) { + workDir = filepath.Clean(workDir) + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return "", ErrInvalidSessionID + } + if strings.Contains(sessionID, "..") || filepath.IsAbs(sessionID) || strings.ContainsAny(sessionID, `/\`) { + return "", ErrInvalidSessionID + } + sessionPath := filepath.Join(workDir, sessionID) + absWork, err := filepath.Abs(workDir) + if err != nil { + return "", fmt.Errorf("resolve work dir: %w", err) + } + absSession, err := filepath.Abs(sessionPath) + if err != nil { + return "", fmt.Errorf("resolve session path: %w", err) + } + rel, err := filepath.Rel(absWork, absSession) + if err != nil { + return "", fmt.Errorf("resolve session path: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", ErrInvalidSessionID + } + info, err := os.Stat(sessionPath) + if err != nil { + if os.IsNotExist(err) { + return "", fmt.Errorf("session not found: %w", err) + } + return "", fmt.Errorf("stat session: %w", err) + } + if !info.IsDir() { + return "", fmt.Errorf("not a directory: %s", sessionPath) + } + return sessionPath, nil +}