Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion code/rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,10 @@ func (r *Rewriter) rewriteFuncDecl(fn *ast.FuncDecl) error {

// RewriteFile rewrites a single file
func (r *Rewriter) RewriteFile(path string) (err error) {
// Reset state up front so previous-file result does not leak into
// files that have nothing to rewrite (for example, doc.go).
r.rewritten = false

defer func() {
if e := recover(); e != nil {
err = fmt.Errorf("%s %v\n%s", r.currentPath, e, debug.Stack())
Expand All @@ -610,7 +614,6 @@ func (r *Rewriter) RewriteFile(path string) (err error) {
r.currentPath = path
r.currentFile = file
r.currsetFset = fset
r.rewritten = false

var failpointImport *ast.ImportSpec
for _, imp := range file.Imports {
Expand Down
44 changes: 44 additions & 0 deletions code/rewriter_state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package code_test

import (
"bytes"
"os"
"path/filepath"
"testing"

"github.com/pingcap/failpoint/code"
"github.com/stretchr/testify/require"
)

func TestRewriteFileResetRewrittenStateForNoDeclFile(t *testing.T) {
t.Parallel()

workDir := t.TempDir()
failpointFile := filepath.Join(workDir, "with_failpoint.go")
docFile := filepath.Join(workDir, "doc.go")

require.NoError(t, os.WriteFile(failpointFile, []byte(`package sample

import "github.com/pingcap/failpoint"

func f() {
failpoint.Inject("fp", func() {})
}
`), 0o644))
require.NoError(t, os.WriteFile(docFile, []byte("package sample\n"), 0o644))

rewriter := code.NewRewriter(workDir)
rewriter.SetAllowNotChecked(true)

var out bytes.Buffer

rewriter.SetOutput(&out)
require.NoError(t, rewriter.RewriteFile(failpointFile))
require.True(t, rewriter.GetRewritten())

out.Reset()
rewriter.SetOutput(&out)
require.NoError(t, rewriter.RewriteFile(docFile))
require.False(t, rewriter.GetRewritten())
require.Zero(t, out.Len())
}
Loading