diff --git a/code/rewriter.go b/code/rewriter.go index 0fabc5a..dee2218 100644 --- a/code/rewriter.go +++ b/code/rewriter.go @@ -594,6 +594,10 @@ func (r *Rewriter) rewriteFuncDecl(fn *ast.FuncDecl) error { // RewriteFile rewrites a single file func (r *Rewriter) RewriteFile(path string) (err error) { + // Reset state up front so previous-file result does not leak into + // files that have nothing to rewrite (for example, doc.go). + r.rewritten = false + defer func() { if e := recover(); e != nil { err = fmt.Errorf("%s %v\n%s", r.currentPath, e, debug.Stack()) @@ -610,7 +614,6 @@ func (r *Rewriter) RewriteFile(path string) (err error) { r.currentPath = path r.currentFile = file r.currsetFset = fset - r.rewritten = false var failpointImport *ast.ImportSpec for _, imp := range file.Imports { diff --git a/code/rewriter_state_test.go b/code/rewriter_state_test.go new file mode 100644 index 0000000..3e241eb --- /dev/null +++ b/code/rewriter_state_test.go @@ -0,0 +1,44 @@ +package code_test + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/pingcap/failpoint/code" + "github.com/stretchr/testify/require" +) + +func TestRewriteFileResetRewrittenStateForNoDeclFile(t *testing.T) { + t.Parallel() + + workDir := t.TempDir() + failpointFile := filepath.Join(workDir, "with_failpoint.go") + docFile := filepath.Join(workDir, "doc.go") + + require.NoError(t, os.WriteFile(failpointFile, []byte(`package sample + +import "github.com/pingcap/failpoint" + +func f() { + failpoint.Inject("fp", func() {}) +} +`), 0o644)) + require.NoError(t, os.WriteFile(docFile, []byte("package sample\n"), 0o644)) + + rewriter := code.NewRewriter(workDir) + rewriter.SetAllowNotChecked(true) + + var out bytes.Buffer + + rewriter.SetOutput(&out) + require.NoError(t, rewriter.RewriteFile(failpointFile)) + require.True(t, rewriter.GetRewritten()) + + out.Reset() + rewriter.SetOutput(&out) + require.NoError(t, rewriter.RewriteFile(docFile)) + require.False(t, rewriter.GetRewritten()) + require.Zero(t, out.Len()) +}