diff --git a/script.go b/script.go index cb5b00f..414fa32 100644 --- a/script.go +++ b/script.go @@ -340,6 +340,7 @@ func CompareCmd(db *DB) script.Cmd { } tableName := args[0] + fileName := args[1] txn := db.ReadTxn() meta := db.GetTable(txn, tableName) @@ -349,16 +350,16 @@ func CompareCmd(db *DB) script.Cmd { tbl := AnyTable{Meta: meta} header := tbl.TableHeader() - data, err := os.ReadFile(s.Path(args[1])) + data, err := os.ReadFile(s.Path(fileName)) if err != nil { - return nil, fmt.Errorf("ReadFile(%s): %w", args[1], err) + return nil, fmt.Errorf("ReadFile(%s): %w", fileName, err) } lines := strings.Split(s.ExpandEnv(string(data), false), "\n") lines = slices.DeleteFunc(lines, func(line string) bool { return strings.TrimSpace(line) == "" }) if len(lines) < 1 { - return nil, fmt.Errorf("%q missing header line, e.g. %q", args[1], strings.Join(header, " ")) + return nil, fmt.Errorf("%q missing header line, e.g. %q", fileName, strings.Join(header, " ")) } columnNames, columnPositions := splitHeaderLine(lines[0]) @@ -370,6 +371,7 @@ func CompareCmd(db *DB) script.Cmd { lines = lines[1:] origLines := lines timeoutChan := time.After(timeout) + var lastActual string for { lines = origLines @@ -378,7 +380,10 @@ func CompareCmd(db *DB) script.Cmd { equal := true var diff bytes.Buffer w := newTabWriter(&diff) - fmt.Fprintf(w, " %s\n", joinByPositions(columnNames, columnPositions, false)) + joined := joinByPositions(columnNames, columnPositions, false) + fmt.Fprintf(w, " %s\n", joined) + var actual strings.Builder + fmt.Fprintf(&actual, " %s\n", joined) objs, watch := tbl.AllWatch(db.ReadTxn()) for obj := range objs { @@ -387,6 +392,7 @@ func CompareCmd(db *DB) script.Cmd { if grepRe != nil && !grepRe.Match([]byte(row)) { continue } + fmt.Fprintf(&actual, "%s\n", row) if len(lines) == 0 { equal = false @@ -409,6 +415,7 @@ func CompareCmd(db *DB) script.Cmd { fmt.Fprintf(w, "+ %s\n", line) equal = false } + lastActual = actual.String() if equal { return nil, nil } @@ -419,6 +426,10 @@ func CompareCmd(db *DB) script.Cmd { return nil, s.Context().Err() case <-timeoutChan: + if s.DoUpdate { + s.FileUpdates[fileName] = lastActual + return nil, nil + } return nil, fmt.Errorf("table mismatch:\n%s", diff.String()) case <-watch: