diff --git a/.golangci.yml b/.golangci.yml index 55914cb..0770b98 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -13,9 +13,14 @@ linters: - misspell - errcheck - gocritic + - gocyclo - gosec - nolintlint settings: + gocyclo: + # goreportcard flags any function with cyclomatic complexity > 15. + # Mirror that threshold here so the local gate, CI, and the badge agree. + min-complexity: 16 errcheck: exclude-functions: - (*database/sql.Stmt).Close @@ -45,6 +50,7 @@ linters: - unused - errcheck - gosec + - gocyclo paths: - testbench - testdata diff --git a/indexer/index.go b/indexer/index.go index 2b6fefc..ab5c152 100644 --- a/indexer/index.go +++ b/indexer/index.go @@ -255,65 +255,88 @@ func (r *indexRun) indexCallGraph() { if pkg.Types == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { continue } - fileByPos := buildFileByPos(pkg) + r.indexPackageCallGraph(pkg) + } +} - onCall := func(from, to, filePath, expr string, line, col int) { - if _, err := r.insertCall.Exec( - r.moduleRoot, pkg.PkgPath, filePath, line, col, from, to, expr, "call", - ); err == nil { - r.callCount++ - } - } - onUnresolved := func(from, filePath, expr string, line, col int) { - if _, err := r.insertUnresolved.Exec( - r.moduleRoot, pkg.PkgPath, filePath, line, col, from, expr, "unresolved", - ); err == nil { - r.unresolvedCount++ - } - } +func (r *indexRun) indexPackageCallGraph(pkg *packages.Package) { + fileByPos := buildFileByPos(pkg) + onCall := r.makeCallEdgeRecorder(pkg) + onUnresolved := r.makeUnresolvedEdgeRecorder(pkg) + for _, f := range pkg.Syntax { + filePath := fileByPos[f] + scanFuncBodies(pkg, f, filePath, func(from, fp string, body ast.Node) { + scanCallsInNode(pkg, from, fp, body, onCall, onUnresolved) + }) + scanCallGraphVarInits(pkg, f, filePath, onCall, onUnresolved) + } +} - for _, f := range pkg.Syntax { - filePath := fileByPos[f] +func (r *indexRun) makeCallEdgeRecorder(pkg *packages.Package) func(from, to, filePath, expr string, line, col int) { + return func(from, to, filePath, expr string, line, col int) { + if _, err := r.insertCall.Exec( + r.moduleRoot, pkg.PkgPath, filePath, line, col, from, to, expr, "call", + ); err == nil { + r.callCount++ + } + } +} - scanFuncBodies(pkg, f, filePath, func(from, fp string, body ast.Node) { - scanCallsInNode(pkg, from, fp, body, onCall, onUnresolved) - }) +func (r *indexRun) makeUnresolvedEdgeRecorder(pkg *packages.Package) func(from, filePath, expr string, line, col int) { + return func(from, filePath, expr string, line, col int) { + if _, err := r.insertUnresolved.Exec( + r.moduleRoot, pkg.PkgPath, filePath, line, col, from, expr, "unresolved", + ); err == nil { + r.unresolvedCount++ + } + } +} - // Package-scope var initializers need special handling for FuncLit bodies. - for _, decl := range f.Decls { - gd, ok := decl.(*ast.GenDecl) - if !ok || gd.Tok != token.VAR { - continue - } - for _, spec := range gd.Specs { - vs, ok := spec.(*ast.ValueSpec) - if !ok { - continue - } - owner := "var" - if len(vs.Names) > 0 { - owner = vs.Names[0].Name - } - for _, value := range vs.Values { - syntheticFrom := fmt.Sprintf("%s.init$var:%s", pkg.PkgPath, owner) - scanCallsInNode(pkg, syntheticFrom, filePath, value, onCall, onUnresolved) - - ast.Inspect(value, func(n ast.Node) bool { - lit, ok := n.(*ast.FuncLit) - if !ok || lit.Body == nil { - return true - } - from := funcLitFQName(pkg.PkgPath, owner, pkg.Fset.PositionFor(lit.Type.Func, false)) - scanCallsInNode(pkg, from, filePath, lit.Body, onCall, onUnresolved) - return false - }) - } - } +// scanCallGraphVarInits walks each package-scope var initializer, recording +// both the calls in the initializer expression itself and any calls inside +// nested FuncLit bodies (which need a synthetic from-fqname tied to position). +func scanCallGraphVarInits(pkg *packages.Package, f *ast.File, filePath string, + onCall func(from, to, filePath, expr string, line, col int), + onUnresolved func(from, filePath, expr string, line, col int), +) { + for _, decl := range f.Decls { + gd, ok := decl.(*ast.GenDecl) + if !ok || gd.Tok != token.VAR { + continue + } + for _, spec := range gd.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + owner := "var" + if len(vs.Names) > 0 { + owner = vs.Names[0].Name + } + for _, value := range vs.Values { + syntheticFrom := fmt.Sprintf("%s.init$var:%s", pkg.PkgPath, owner) + scanCallsInNode(pkg, syntheticFrom, filePath, value, onCall, onUnresolved) + scanFuncLitsInVarInit(pkg, owner, filePath, value, onCall, onUnresolved) } } } } +func scanFuncLitsInVarInit(pkg *packages.Package, owner, filePath string, value ast.Node, + onCall func(from, to, filePath, expr string, line, col int), + onUnresolved func(from, filePath, expr string, line, col int), +) { + ast.Inspect(value, func(n ast.Node) bool { + lit, ok := n.(*ast.FuncLit) + if !ok || lit.Body == nil { + return true + } + from := funcLitFQName(pkg.PkgPath, owner, pkg.Fset.PositionFor(lit.Type.Func, false)) + scanCallsInNode(pkg, from, filePath, lit.Body, onCall, onUnresolved) + return false + }) +} + // scanFuncBodies iterates function declarations in a file, calling fn with // the function's fqname, file path, and body for each. func scanFuncBodies(pkg *packages.Package, f *ast.File, filePath string, fn func(from, filePath string, body ast.Node)) { diff --git a/indexer/stale.go b/indexer/stale.go index e5518eb..6edc06e 100644 --- a/indexer/stale.go +++ b/indexer/stale.go @@ -103,102 +103,122 @@ func gitFastPath(moduleRoot, packagePath, indexedCommit string) (bool, error) { return false, nil } +type storedFile struct { + path string + hash string +} + // fileHashFallback checks staleness by comparing stored file lists and hashes // against the current state on disk. func fileHashFallback(db *sql.DB, moduleRoot, packagePath string) (bool, error) { + storedFiles, err := loadStoredFiles(db, moduleRoot, packagePath) + if err != nil { + return true, err + } + // Pre-item13 DB has no stored file rows — treat as stale to force a rebuild. + if len(storedFiles) == 0 { + return true, nil + } + if hasNewProductionFiles(moduleRoot, packagePath, storedSetByBase(storedFiles)) { + return true, nil + } + stale, complete := comparePerFileHashes(storedFiles) + if stale { + return true, nil + } + if complete { + return false, nil + } + return comparePackageHashFromDB(db, moduleRoot, packagePath, storedFiles) +} + +func loadStoredFiles(db *sql.DB, moduleRoot, packagePath string) ([]storedFile, error) { rows, err := db.Query( `SELECT file_path, file_hash FROM package_files WHERE module_root = ? AND package_path = ?`, moduleRoot, packagePath, ) if err != nil { - return true, fmt.Errorf("query package_files: %w", err) + return nil, fmt.Errorf("query package_files: %w", err) } defer rows.Close() - - type storedFile struct { - path string - hash string - } - var storedFiles []storedFile + var out []storedFile for rows.Next() { var sf storedFile if err := rows.Scan(&sf.path, &sf.hash); err != nil { - return true, err + return nil, err } - storedFiles = append(storedFiles, sf) - } - if err := rows.Err(); err != nil { - return true, err - } - - // No stored files means pre-item13 DB — treat as stale. - if len(storedFiles) == 0 { - return true, nil + out = append(out, sf) } + return out, rows.Err() +} - // Build a set of stored file basenames for new-file detection. - storedSet := make(map[string]bool, len(storedFiles)) - for _, sf := range storedFiles { - storedSet[filepath.Base(sf.path)] = true +func storedSetByBase(files []storedFile) map[string]bool { + s := make(map[string]bool, len(files)) + for _, f := range files { + s[filepath.Base(f.path)] = true } + return s +} - // Check for new non-test .go files on disk that weren't in the index. - // We compare by basename rather than raw count because the index may - // exclude _test.go files (withTests=false), and a naive count would - // produce false positives when test files exist on disk. +// hasNewProductionFiles returns true when the package directory contains a +// non-test .go file not present in the indexed file list. Resolution failures +// are treated as stale (conservative). +func hasNewProductionFiles(moduleRoot, packagePath string, indexed map[string]bool) bool { pkgDir, err := packageDir(moduleRoot, packagePath) if err != nil { - return true, nil + return true } diskFiles, err := listGoFiles(pkgDir) if err != nil { - return true, nil + return true } for _, df := range diskFiles { base := filepath.Base(df) - if !storedSet[base] && !strings.HasSuffix(base, "_test.go") { - return true, nil + if !indexed[base] && !strings.HasSuffix(base, "_test.go") { + return true } } + return false +} - // Compare per-file hashes from package_files. - for _, sf := range storedFiles { +// comparePerFileHashes returns (stale, complete). complete=false means a stored +// file lacked a hash and the per-file comparison was abandoned — caller should +// fall back to the package-level files_hash check (matches legacy behavior). +func comparePerFileHashes(files []storedFile) (stale, complete bool) { + for _, sf := range files { if sf.hash == "" { - // No hash stored — fall through to package-level files_hash check. - goto packageHash + return false, false } currentHash, err := hashFile(sf.path) - if err != nil { - return true, nil // file missing or unreadable → stale - } - if currentHash != sf.hash { - return true, nil + if err != nil || currentHash != sf.hash { + return true, true } } - return false, nil + return false, true +} -packageHash: - // Compare stored files_hash from package_meta. +func comparePackageHashFromDB(db *sql.DB, moduleRoot, packagePath string, files []storedFile) (bool, error) { var storedHash string - err = db.QueryRow( + err := db.QueryRow( `SELECT files_hash FROM package_meta WHERE module_root = ? AND package_path = ?`, moduleRoot, packagePath, ).Scan(&storedHash) if err != nil || storedHash == "" { return true, nil } + return computeAndCompareFilesHash(files, storedHash) +} - // Compute current hash over the stored file list (same paths as index time). - var storedPaths []string - for _, sf := range storedFiles { - storedPaths = append(storedPaths, sf.path) +func computeAndCompareFilesHash(files []storedFile, storedHash string) (bool, error) { + paths := make([]string, 0, len(files)) + for _, sf := range files { + paths = append(paths, sf.path) } - sort.Strings(storedPaths) - currentHash, err := ComputeFilesHash(storedPaths) + sort.Strings(paths) + currentHash, err := ComputeFilesHash(paths) if err != nil { return true, nil } - return currentHash != storedHash, nil } @@ -324,6 +344,52 @@ func hashFile(path string) (string, error) { // ---- store.ReadStore-backed variants (Phase 3) ---- +type gitFastPathState struct { + available bool + repoRoot string + changedFiles map[string]bool +} + +func loadGitFastPathState(indexedCommit, moduleRoot string) gitFastPathState { + if indexedCommit == "" { + return gitFastPathState{} + } + root, err := gitRoot(moduleRoot) + if err != nil { + return gitFastPathState{} + } + changed, err := gitChangedFiles(moduleRoot, indexedCommit) + if err != nil { + return gitFastPathState{} + } + return gitFastPathState{available: true, repoRoot: root, changedFiles: changed} +} + +// pkgStaleViaFastPath reports (handled, stale). When handled=false, caller must +// fall back to the file-hash check. +func pkgStaleViaFastPath(pkg store.PackageMetaRow, st gitFastPathState) (bool, bool) { + if !st.available { + return false, false + } + pkgDir, err := packageDir(pkg.ModuleRoot, pkg.PackagePath) + if err != nil { + return false, false + } + relPkg, err := filepath.Rel(st.repoRoot, pkgDir) + if err != nil { + return false, false + } + relPkg = filepath.ToSlash(relPkg) + prefix := relPkg + "/" + for f := range st.changedFiles { + f = filepath.ToSlash(f) + if f == relPkg || strings.HasPrefix(f, prefix) { + return true, true + } + } + return true, false +} + // StalePackagesStore is the store.ReadStore-backed version of StalePackages. // It runs the git diff once and checks all packages against the cached result, // avoiding per-package process spawns. @@ -336,51 +402,18 @@ func StalePackagesStore(rs store.ReadStore) ([]string, error) { if len(pkgs) == 0 { return nil, nil } - - // Read indexed_commit once (same for all packages). indexedCommit, _ := rs.IndexedCommit(ctx) - - // Run git diff once and cache the changed files set. - var changedFiles map[string]bool - var repoRoot string - gitFastPathAvailable := false - if indexedCommit != "" { - root, rootErr := gitRoot(pkgs[0].ModuleRoot) - if rootErr == nil { - changed, diffErr := gitChangedFiles(pkgs[0].ModuleRoot, indexedCommit) - if diffErr == nil { - changedFiles = changed - repoRoot = root - gitFastPathAvailable = true - } - } - } + fast := loadGitFastPathState(indexedCommit, pkgs[0].ModuleRoot) var stale []string for _, pkg := range pkgs { - if gitFastPathAvailable { - pkgDir, err := packageDir(pkg.ModuleRoot, pkg.PackagePath) - if err == nil { - relPkg, err := filepath.Rel(repoRoot, pkgDir) - if err == nil { - relPkg = filepath.ToSlash(relPkg) - prefix := relPkg + "/" - isStale := false - for f := range changedFiles { - f = filepath.ToSlash(f) - if f == relPkg || strings.HasPrefix(f, prefix) { - isStale = true - break - } - } - if isStale { - stale = append(stale, pkg.PackagePath) - } - continue - } + handled, isStale := pkgStaleViaFastPath(pkg, fast) + if handled { + if isStale { + stale = append(stale, pkg.PackagePath) } + continue } - // Fallback: per-package file-hash check. isStale, err := fileHashFallbackStore(rs, pkg.ModuleRoot, pkg.PackagePath) if err != nil || isStale { stale = append(stale, pkg.PackagePath) @@ -399,52 +432,51 @@ func fileHashFallbackStore(rs store.ReadStore, moduleRoot, packagePath string) ( if len(files) == 0 { return true, nil } - - storedSet := make(map[string]bool, len(files)) - for _, f := range files { - storedSet[filepath.Base(f.FilePath)] = true - } - - pkgDir, err := packageDir(moduleRoot, packagePath) - if err != nil { + if hasNewProductionFiles(moduleRoot, packagePath, packageFileSetByBase(files)) { return true, nil } - diskFiles, err := listGoFiles(pkgDir) - if err != nil { + stale, complete := comparePerFileHashesStore(files) + if stale { return true, nil } - for _, df := range diskFiles { - base := filepath.Base(df) - if !storedSet[base] && !strings.HasSuffix(base, "_test.go") { - return true, nil - } + if complete { + return false, nil } + return comparePackageHashFromStore(rs, ctx, moduleRoot, packagePath, files) +} + +func packageFileSetByBase(files []store.PackageFile) map[string]bool { + s := make(map[string]bool, len(files)) + for _, f := range files { + s[filepath.Base(f.FilePath)] = true + } + return s +} +func comparePerFileHashesStore(files []store.PackageFile) (stale, complete bool) { for _, f := range files { if f.FileHash == "" { - goto packageHash + return false, false } currentHash, err := hashFile(f.FilePath) - if err != nil { - return true, nil - } - if currentHash != f.FileHash { - return true, nil + if err != nil || currentHash != f.FileHash { + return true, true } } - return false, nil + return false, true +} -packageHash: +func comparePackageHashFromStore(rs store.ReadStore, ctx context.Context, moduleRoot, packagePath string, files []store.PackageFile) (bool, error) { storedHash, err := rs.StoredFilesHash(ctx, moduleRoot, packagePath) if err != nil || storedHash == "" { return true, nil } - var storedPaths []string + paths := make([]string, 0, len(files)) for _, f := range files { - storedPaths = append(storedPaths, f.FilePath) + paths = append(paths, f.FilePath) } - sort.Strings(storedPaths) - currentHash, err := ComputeFilesHash(storedPaths) + sort.Strings(paths) + currentHash, err := ComputeFilesHash(paths) if err != nil { return true, nil } diff --git a/internal/cmd/callees.go b/internal/cmd/callees.go index 556b5b7..45fc19c 100644 --- a/internal/cmd/callees.go +++ b/internal/cmd/callees.go @@ -131,7 +131,6 @@ func execCallees(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy if strings.TrimSpace(symbol) == "" { return errors.New("--symbol is required") } - if autoReindex { checkAndAutoReindex(db, false, false) } @@ -140,13 +139,7 @@ func execCallees(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy symbol, resolveNote = resolveSymbolInput(db, symbol, pkg) ctx := context.Background() - opts := store.CalleesOpts{ - Symbol: symbol, - Fuzzy: fuzzy, - Pkg: pkg, - Unique: unique, - Limit: limit, - } + opts := store.CalleesOpts{Symbol: symbol, Fuzzy: fuzzy, Pkg: pkg, Unique: unique, Limit: limit} if countOnly { n, err := rs.CountCallees(ctx, opts) @@ -157,11 +150,38 @@ func execCallees(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy return nil } - callees := make([]calleeRow, 0) - storeRows, err := rs.DirectCallees(ctx, opts) + callees, err := collectCallees(rs, ctx, opts) if err != nil { return err } + unresolved := []calleesUnresolvedRow{} + if includeUnresolved { + urows, err := queryUnresolvedCallees(db, symbol, fuzzy, unique, limit) + if err != nil { + return err + } + if urows != nil { + unresolved = urows + } + } + + if asJSON { + return emitCalleesJSON(db, dbPath, symbol, resolveNote, callees, unresolved) + } + printCallees(symbol, callees, unique) + printCalleesUnresolved(symbol, unresolved, unique) + if resolveNote != "" { + fmt.Printf("# %s\n", resolveNote) + } + return nil +} + +func collectCallees(rs store.ReadStore, ctx context.Context, opts store.CalleesOpts) ([]calleeRow, error) { + storeRows, err := rs.DirectCallees(ctx, opts) + if err != nil { + return nil, err + } + out := make([]calleeRow, 0, len(storeRows)) for _, sr := range storeRows { r := calleeRow{ To: sr.FQName, @@ -172,62 +192,52 @@ func execCallees(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy PackagePath: sr.PackagePath, } r.Name = shortName(r.To) - if asJSON { - callees = append(callees, r) - } else { - if unique { - fmt.Printf("%s [%s]\n", r.To, r.Kind) - } else { - marker := "" - if r.Kind == "ref" { - marker = " [func-ref]" - } - fmt.Printf("%s -> %s\t%s:%d:%d%s\n", symbol, r.To, r.File, r.Line, r.Col, marker) - } - } + out = append(out, r) } + return out, nil +} - unresolved := make([]calleesUnresolvedRow, 0) - if includeUnresolved { - urows, err := queryUnresolvedCallees(db, symbol, fuzzy, unique, limit) - if err != nil { - return err +func printCallees(symbol string, rows []calleeRow, unique bool) { + for _, r := range rows { + if unique { + fmt.Printf("%s [%s]\n", r.To, r.Kind) + continue } - for _, r := range urows { - if asJSON { - unresolved = append(unresolved, r) - } else { - if unique { - fmt.Printf("~> %s [unresolved]\n", r.Expr) - } else { - fmt.Printf("%s ~> %s\t%s:%d:%d [unresolved]\n", symbol, r.Expr, r.File, r.Line, r.Col) - } - } + marker := "" + if r.Kind == "ref" { + marker = " [func-ref]" } + fmt.Printf("%s -> %s\t%s:%d:%d%s\n", symbol, r.To, r.File, r.Line, r.Col, marker) } +} - if asJSON { - payload := map[string]any{ - "callees": callees, - "callees_count": len(callees), - "unresolved": unresolved, - "unresolved_count": len(unresolved), - "env": collectEnv(dbPath), +func printCalleesUnresolved(symbol string, rows []calleesUnresolvedRow, unique bool) { + for _, r := range rows { + if unique { + fmt.Printf("~> %s [unresolved]\n", r.Expr) + continue } - if resolveNote != "" { - payload["resolved"] = resolveNote - } - if len(callees) == 0 && len(unresolved) == 0 { - if h := symbolHint(db, symbol); h != "" { - payload["hint"] = h - } - } - enc := json.NewEncoder(os.Stdout) - enc.SetEscapeHTML(false) - return enc.Encode(payload) + fmt.Printf("%s ~> %s\t%s:%d:%d [unresolved]\n", symbol, r.Expr, r.File, r.Line, r.Col) + } +} + +func emitCalleesJSON(db *sql.DB, dbPath, symbol, resolveNote string, callees []calleeRow, unresolved []calleesUnresolvedRow) error { + payload := map[string]any{ + "callees": callees, + "callees_count": len(callees), + "unresolved": unresolved, + "unresolved_count": len(unresolved), + "env": collectEnv(dbPath), } if resolveNote != "" { - fmt.Printf("# %s\n", resolveNote) + payload["resolved"] = resolveNote } - return nil + if len(callees) == 0 && len(unresolved) == 0 { + if h := symbolHint(db, symbol); h != "" { + payload["hint"] = h + } + } + enc := json.NewEncoder(os.Stdout) + enc.SetEscapeHTML(false) + return enc.Encode(payload) } diff --git a/internal/cmd/callers.go b/internal/cmd/callers.go index 10bb6f2..691e577 100644 --- a/internal/cmd/callers.go +++ b/internal/cmd/callers.go @@ -167,12 +167,7 @@ func execCallers(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy if countOnly && explain { return errors.New("--count and --explain cannot be used together") } - if depth < 1 { - depth = 1 - } - if depth > 10 { - depth = 10 - } + depth = clampCallersDepth(depth) if autoReindex { checkAndAutoReindex(db, false, false) @@ -182,61 +177,26 @@ func execCallers(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy var resolveNote string symbol, resolveNote = resolveSymbolInput(db, symbol, pkg) - var explainData *explainPayload - if explain { - explainData = &explainPayload{ - Command: "callers", - Input: rawSymbol, - ResolvedSymbol: symbol, - Resolution: resolveNote, - Filters: map[string]any{ - "pkg": pkg, - "include_unresolved": includeUnresolved, - "is_test": isTest, - "fuzzy": fuzzy, - }, - Traversal: map[string]any{ - "depth": depth, - "limit": limit, - "search_mode": map[string]any{ - "exact_first": true, - "fuzzy_fallback_when_exact_has_none": fuzzy, - }, - }, - Notes: []string{ - "pkg filter applies to caller package paths", - "depth controls BFS hops over callers", - }, - } - } + explainData := buildCallersExplain(explain, rawSymbol, symbol, resolveNote, pkg, includeUnresolved, isTest, fuzzy, depth, limit) allCallers, err := buildCallersBFS(rs, symbol, fuzzy, pkg, depth, limit, isTest) if err != nil { return err } - - unresolved := make([]callersUnresolvedRow, 0) - if includeUnresolved { - ctx := context.Background() - urows, err := rs.UnresolvedCallers(ctx, symbol, fuzzy, limit) - if err != nil { - return err - } - for _, ur := range urows { - r := callersUnresolvedRow{From: ur.From, Expr: ur.Expr, File: ur.File, Line: ur.Line, Col: ur.Col} - if asJSON { - unresolved = append(unresolved, r) - } else { - fmt.Printf("%s ~> %s\t%s:%d:%d [unresolved]\n", r.From, r.Expr, r.File, r.Line, r.Col) - } - } + unresolved, err := collectUnresolvedCallers(rs, symbol, fuzzy, includeUnresolved, limit) + if err != nil { + return err } + // Preserve original output ordering: in human mode, unresolved lines stream + // before count/callers regardless of --count or --explain. + if !asJSON { + printCallersUnresolved(unresolved) + } if countOnly { fmt.Println(len(allCallers)) return nil } - if asJSON { return formatCallersJSON(db, symbol, resolveNote, dbPath, allCallers, unresolved, depth, explainData) } @@ -245,11 +205,77 @@ func execCallers(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy fmt.Print(formatExplainText(explainData)) fmt.Println() } - if resolveNote != "" { fmt.Printf("# %s\n", resolveNote) } - for _, r := range allCallers { + printCallers(allCallers, depth) + return nil +} + +func clampCallersDepth(d int) int { + if d < 1 { + return 1 + } + if d > 10 { + return 10 + } + return d +} + +func buildCallersExplain(enabled bool, rawSymbol, resolvedSymbol, resolveNote, pkg string, includeUnresolved, isTest, fuzzy bool, depth, limit int) *explainPayload { + if !enabled { + return nil + } + return &explainPayload{ + Command: "callers", + Input: rawSymbol, + ResolvedSymbol: resolvedSymbol, + Resolution: resolveNote, + Filters: map[string]any{ + "pkg": pkg, + "include_unresolved": includeUnresolved, + "is_test": isTest, + "fuzzy": fuzzy, + }, + Traversal: map[string]any{ + "depth": depth, + "limit": limit, + "search_mode": map[string]any{ + "exact_first": true, + "fuzzy_fallback_when_exact_has_none": fuzzy, + }, + }, + Notes: []string{ + "pkg filter applies to caller package paths", + "depth controls BFS hops over callers", + }, + } +} + +func collectUnresolvedCallers(rs store.ReadStore, symbol string, fuzzy, include bool, limit int) ([]callersUnresolvedRow, error) { + if !include { + return []callersUnresolvedRow{}, nil + } + ctx := context.Background() + urows, err := rs.UnresolvedCallers(ctx, symbol, fuzzy, limit) + if err != nil { + return nil, err + } + out := make([]callersUnresolvedRow, 0, len(urows)) + for _, ur := range urows { + out = append(out, callersUnresolvedRow{From: ur.From, Expr: ur.Expr, File: ur.File, Line: ur.Line, Col: ur.Col}) + } + return out, nil +} + +func printCallersUnresolved(rows []callersUnresolvedRow) { + for _, r := range rows { + fmt.Printf("%s ~> %s\t%s:%d:%d [unresolved]\n", r.From, r.Expr, r.File, r.Line, r.Col) + } +} + +func printCallers(rows []callerRow, depth int) { + for _, r := range rows { marker := "" if r.Kind == "ref" { marker = " [func-ref]" @@ -260,5 +286,4 @@ func execCallers(rs store.ReadStore, db *sql.DB, symbol string, limit int, fuzzy } fmt.Printf("%s -> %s\t%s:%d:%d%s%s\n", r.From, r.To, r.File, r.Line, r.Col, marker, depthSuffix) } - return nil } diff --git a/internal/cmd/index.go b/internal/cmd/index.go index ca60618..0a83ecd 100644 --- a/internal/cmd/index.go +++ b/internal/cmd/index.go @@ -49,13 +49,17 @@ func newIndexCmd() *cobra.Command { return cmd } +type indexCounts struct { + symbols, calls, unresolved, typeRefs, warnings int +} + func runIndex(dbPath, root string, enableCGO, force, withTests, asJSON, benchJSON bool) error { + benchStart := time.Now() var m0 runtime.MemStats if benchJSON { runtime.GC() runtime.ReadMemStats(&m0) } - benchStart := time.Now() absRoot, err := filepath.Abs(root) if err != nil { @@ -69,10 +73,7 @@ func runIndex(dbPath, root string, enableCGO, force, withTests, asJSON, benchJSO return fmt.Errorf("no go.mod found under %s", absRoot) } - if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { - return err - } - db, err := sql.Open(sqlite.DriverName, dbPath) + db, err := openIndexDB(dbPath, force, modules) if err != nil { return err } @@ -83,114 +84,144 @@ func runIndex(dbPath, root string, enableCGO, force, withTests, asJSON, benchJSO } }() + counts := indexAllModules(db, modules, enableCGO, withTests) + + indexedCommit := captureGitCommit(absRoot) + if _, err := db.Exec(`INSERT INTO index_meta(tool_version, go_version, indexed_at, root, warnings, indexed_commit) VALUES (?, ?, ?, ?, ?, ?)`, + Version, runtime.Version(), time.Now().UTC().Format(time.RFC3339), absRoot, counts.warnings, indexedCommit); err != nil { + log.Printf("warn: index_meta insert: %v", err) + } + + log.Printf("done: %d modules, %d symbols, %d calls, %d unresolved, %d type_refs", + len(modules), counts.symbols, counts.calls, counts.unresolved, counts.typeRefs) + + if benchJSON { + if err := db.Close(); err != nil { + log.Printf("warn: db close: %v", err) + } + dbClosed = true + return emitBenchJSON(dbPath, benchStart, m0, len(modules), counts) + } + if asJSON { + return emitIndexJSON(len(modules), counts) + } + return nil +} + +func openIndexDB(dbPath string, force bool, modules []string) (*sql.DB, error) { + if err := os.MkdirAll(filepath.Dir(dbPath), 0o755); err != nil { + return nil, err + } + db, err := sql.Open(sqlite.DriverName, dbPath) + if err != nil { + return nil, err + } if force { if err := indexer.ResetSchema(db); err != nil { - return err - } - } else { - if err := indexer.EnsureSchema(db); err != nil { - return err + _ = db.Close() + return nil, err } - // Detect and purge orphaned modules (previously indexed but no longer on disk). - discovered := make(map[string]bool, len(modules)) - for _, m := range modules { - discovered[m] = true + return db, nil + } + if err := indexer.EnsureSchema(db); err != nil { + _ = db.Close() + return nil, err + } + purgeOrphanedModules(db, modules) + return db, nil +} + +// purgeOrphanedModules drops index entries for modules that were previously +// indexed but no longer exist on disk under the current root. +func purgeOrphanedModules(db *sql.DB, modules []string) { + discovered := make(map[string]bool, len(modules)) + for _, m := range modules { + discovered[m] = true + } + rows, err := db.Query(`SELECT DISTINCT module_root FROM package_meta`) + if err != nil { + return + } + defer rows.Close() + for rows.Next() { + var prev string + if rows.Scan(&prev) != nil || discovered[prev] { + continue } - rows, err := db.Query(`SELECT DISTINCT module_root FROM package_meta`) - if err == nil { - defer rows.Close() - for rows.Next() { - var prev string - if rows.Scan(&prev) == nil && !discovered[prev] { - log.Printf("purging deleted module %s ...", prev) - if err := indexer.PurgeModule(db, prev); err != nil { - log.Printf("warn: purge %s: %v", prev, err) - } - } - } + log.Printf("purging deleted module %s ...", prev) + if err := indexer.PurgeModule(db, prev); err != nil { + log.Printf("warn: purge %s: %v", prev, err) } } +} - totalSymbols := 0 - totalCalls := 0 - totalUnresolved := 0 - totalTypeRefs := 0 - totalWarnings := 0 +func indexAllModules(db *sql.DB, modules []string, enableCGO, withTests bool) indexCounts { + var c indexCounts for i, mod := range modules { log.Printf("[%d/%d] indexing %s ...", i+1, len(modules), mod) symN, callN, unresN, typeRefN, err := indexer.IndexModule(db, mod, enableCGO, withTests) - totalSymbols += symN - totalCalls += callN - totalUnresolved += unresN - totalTypeRefs += typeRefN + c.symbols += symN + c.calls += callN + c.unresolved += unresN + c.typeRefs += typeRefN if err != nil { - totalWarnings++ + c.warnings++ log.Printf("warn: module %s: %v", mod, err) } log.Printf(" done: %d symbols, %d calls, %d unresolved, %d type_refs", symN, callN, unresN, typeRefN) } + return c +} - // Capture the current git commit for stale-detection fast path. - indexedCommit := "" - if gitCmd := exec.Command("git", "rev-parse", "HEAD"); gitCmd != nil { - gitCmd.Dir = absRoot - if out, err := gitCmd.Output(); err == nil { - indexedCommit = strings.TrimSpace(string(out)) - } - } - - if _, err := db.Exec(`INSERT INTO index_meta(tool_version, go_version, indexed_at, root, warnings, indexed_commit) VALUES (?, ?, ?, ?, ?, ?)`, - Version, runtime.Version(), time.Now().UTC().Format(time.RFC3339), absRoot, totalWarnings, indexedCommit); err != nil { - log.Printf("warn: index_meta insert: %v", err) +func captureGitCommit(absRoot string) string { + gitCmd := exec.Command("git", "rev-parse", "HEAD") + gitCmd.Dir = absRoot + out, err := gitCmd.Output() + if err != nil { + return "" } + return strings.TrimSpace(string(out)) +} - log.Printf("done: %d modules, %d symbols, %d calls, %d unresolved, %d type_refs", len(modules), totalSymbols, totalCalls, totalUnresolved, totalTypeRefs) - - if benchJSON { - if err := db.Close(); err != nil { - log.Printf("warn: db close: %v", err) - } - dbClosed = true - wallNs := time.Since(benchStart).Nanoseconds() - var m1 runtime.MemStats - runtime.ReadMemStats(&m1) - dbSize := int64(-1) - if fi, err := os.Stat(dbPath); err == nil { - dbSize = fi.Size() - } - enc := json.NewEncoder(os.Stdout) - enc.SetEscapeHTML(false) - return enc.Encode(map[string]any{ - "wall_ns": wallNs, - "total_alloc_bytes": m1.TotalAlloc - m0.TotalAlloc, - "heap_alloc_bytes": m1.HeapAlloc, - "sys_bytes": m1.Sys, - "num_gc": m1.NumGC - m0.NumGC, - "pause_total_ns": m1.PauseTotalNs - m0.PauseTotalNs, - "mallocs": m1.Mallocs - m0.Mallocs, - "frees": m1.Frees - m0.Frees, - "db_path": dbPath, - "db_size_bytes": dbSize, - "modules": len(modules), - "symbols": totalSymbols, - "calls": totalCalls, - "unresolved": totalUnresolved, - "type_refs": totalTypeRefs, - "go_version": runtime.Version(), - "tool_version": Version, - }) +func emitBenchJSON(dbPath string, benchStart time.Time, m0 runtime.MemStats, modules int, c indexCounts) error { + wallNs := time.Since(benchStart).Nanoseconds() + var m1 runtime.MemStats + runtime.ReadMemStats(&m1) + dbSize := int64(-1) + if fi, err := os.Stat(dbPath); err == nil { + dbSize = fi.Size() } + enc := json.NewEncoder(os.Stdout) + enc.SetEscapeHTML(false) + return enc.Encode(map[string]any{ + "wall_ns": wallNs, + "total_alloc_bytes": m1.TotalAlloc - m0.TotalAlloc, + "heap_alloc_bytes": m1.HeapAlloc, + "sys_bytes": m1.Sys, + "num_gc": m1.NumGC - m0.NumGC, + "pause_total_ns": m1.PauseTotalNs - m0.PauseTotalNs, + "mallocs": m1.Mallocs - m0.Mallocs, + "frees": m1.Frees - m0.Frees, + "db_path": dbPath, + "db_size_bytes": dbSize, + "modules": modules, + "symbols": c.symbols, + "calls": c.calls, + "unresolved": c.unresolved, + "type_refs": c.typeRefs, + "go_version": runtime.Version(), + "tool_version": Version, + }) +} - if asJSON { - enc := json.NewEncoder(os.Stdout) - enc.SetEscapeHTML(false) - return enc.Encode(map[string]any{ - "indexed": len(modules), - "symbols": totalSymbols, - "calls": totalCalls, - "unresolved": totalUnresolved, - "type_refs": totalTypeRefs, - }) - } - return nil +func emitIndexJSON(modules int, c indexCounts) error { + enc := json.NewEncoder(os.Stdout) + enc.SetEscapeHTML(false) + return enc.Encode(map[string]any{ + "indexed": modules, + "symbols": c.symbols, + "calls": c.calls, + "unresolved": c.unresolved, + "type_refs": c.typeRefs, + }) } diff --git a/internal/cmd/references.go b/internal/cmd/references.go index 575bf87..5b77a51 100644 --- a/internal/cmd/references.go +++ b/internal/cmd/references.go @@ -51,6 +51,19 @@ func newReferencesCmd() *cobra.Command { return cmd } +type refRow struct { + From string `json:"from"` + FromName string `json:"from_name"` + To string `json:"to"` + ToName string `json:"to_name"` + RefKind string `json:"ref_kind"` + File string `json:"file"` + Line int `json:"line"` + Col int `json:"col"` + Expr string `json:"expr"` + PackagePath string `json:"package_path"` +} + func execReferences(rs store.ReadStore, dbPath, symbol, pkg, refKind, from string, limit int, countOnly, asJSON bool) error { if strings.TrimSpace(symbol) == "" { return errors.New("--symbol is required") @@ -60,24 +73,7 @@ func execReferences(rs store.ReadStore, dbPath, symbol, pkg, refKind, from strin } ctx := context.Background() - - // Use the store's ResolveSymbolName to resolve a short name to a fqname. - resolvedSymbol := symbol - resolvedNote := "" - if !strings.Contains(symbol, "/") { - names, err := rs.ResolveSymbolName(ctx, symbol, pkg) - if err == nil { - switch len(names) { - case 1: - resolvedSymbol = names[0] - resolvedNote = "resolved short name '" + symbol + "' to '" + names[0] + "'" - case 0: - // leave as-is - default: - resolvedNote = "ambiguous short name '" + symbol + "' — use exact fqname or --pkg to disambiguate" - } - } - } + resolvedSymbol, resolvedNote := resolveTypeRef(ctx, rs, symbol, pkg) opts := store.ReferencesOpts{ Symbol: resolvedSymbol, @@ -101,19 +97,6 @@ func execReferences(rs store.ReadStore, dbPath, symbol, pkg, refKind, from strin return err } - type refRow struct { - From string `json:"from"` - FromName string `json:"from_name"` - To string `json:"to"` - ToName string `json:"to_name"` - RefKind string `json:"ref_kind"` - File string `json:"file"` - Line int `json:"line"` - Col int `json:"col"` - Expr string `json:"expr"` - PackagePath string `json:"package_path"` - } - results := make([]refRow, 0, len(result.Refs)) for _, r := range result.Refs { results = append(results, refRow{ @@ -131,28 +114,53 @@ func execReferences(rs store.ReadStore, dbPath, symbol, pkg, refKind, from strin } if asJSON { - out := map[string]any{ - "references": results, - "count": len(results), - "total_matched": result.TotalMatched, - "truncated": result.TotalMatched > limit, - "env": collectEnv(dbPath), - } - if resolvedNote != "" { - out["resolved"] = resolvedNote - } - if len(results) == 0 { - hints, _ := rs.SymbolHint(ctx, resolvedSymbol) - if len(hints) > 0 { - out["hint"] = "Exact fqname mismatch. Similar: " + strings.Join(hints, " | ") + ". Use exact fqname or --fuzzy." - } + return emitReferencesJSON(ctx, rs, dbPath, resolvedSymbol, resolvedNote, results, result.TotalMatched, limit) + } + printReferences(ctx, rs, resolvedSymbol, resolvedNote, results, result.TotalMatched, limit) + return nil +} + +func resolveTypeRef(ctx context.Context, rs store.ReadStore, symbol, pkg string) (string, string) { + if strings.Contains(symbol, "/") { + return symbol, "" + } + names, err := rs.ResolveSymbolName(ctx, symbol, pkg) + if err != nil { + return symbol, "" + } + switch len(names) { + case 1: + return names[0], "resolved short name '" + symbol + "' to '" + names[0] + "'" + case 0: + return symbol, "" + default: + return symbol, "ambiguous short name '" + symbol + "' — use exact fqname or --pkg to disambiguate" + } +} + +func emitReferencesJSON(ctx context.Context, rs store.ReadStore, dbPath, resolvedSymbol, resolvedNote string, results []refRow, totalMatched, limit int) error { + out := map[string]any{ + "references": results, + "count": len(results), + "total_matched": totalMatched, + "truncated": totalMatched > limit, + "env": collectEnv(dbPath), + } + if resolvedNote != "" { + out["resolved"] = resolvedNote + } + if len(results) == 0 { + hints, _ := rs.SymbolHint(ctx, resolvedSymbol) + if len(hints) > 0 { + out["hint"] = "Exact fqname mismatch. Similar: " + strings.Join(hints, " | ") + ". Use exact fqname or --fuzzy." } - enc := json.NewEncoder(os.Stdout) - enc.SetEscapeHTML(false) - return enc.Encode(out) } + enc := json.NewEncoder(os.Stdout) + enc.SetEscapeHTML(false) + return enc.Encode(out) +} - // Text output +func printReferences(ctx context.Context, rs store.ReadStore, resolvedSymbol, resolvedNote string, results []refRow, totalMatched, limit int) { if resolvedNote != "" { fmt.Printf("note: %s\n\n", resolvedNote) } @@ -165,12 +173,11 @@ func execReferences(rs store.ReadStore, dbPath, symbol, pkg, refKind, from strin if len(hints) > 0 { fmt.Printf("hint: Exact fqname mismatch. Similar: %s. Use exact fqname.\n", strings.Join(hints, " | ")) } - } else { - fmt.Printf("\n%d reference(s)", len(results)) - if result.TotalMatched > limit { - fmt.Printf(" (truncated; %d total)", result.TotalMatched) - } - fmt.Println() + return } - return nil + fmt.Printf("\n%d reference(s)", len(results)) + if totalMatched > limit { + fmt.Printf(" (truncated; %d total)", totalMatched) + } + fmt.Println() } diff --git a/store/sqlite/read.go b/store/sqlite/read.go index 3c22934..55b7c25 100644 --- a/store/sqlite/read.go +++ b/store/sqlite/read.go @@ -139,9 +139,47 @@ func buildFindWhere(opts store.FindOpts) (string, []any) { // ---- DefSymbol ---- func (s *SQLiteStore) DefSymbol(ctx context.Context, name, pkg string) ([]store.SymbolRow, error) { - const cols = `SELECT fqname, kind, file_path, line, col, signature, package_path FROM symbols` + scan := s.symbolScanner(ctx) - scanRows := func(q string, args ...any) ([]store.SymbolRow, error) { + results, err := scanByFqname(scan, name, pkg) + if err != nil { + return nil, fmt.Errorf("DefSymbol fqname lookup: %w", err) + } + if len(results) > 0 { + return results, nil + } + if base := stripInstantiationArgs(name); base != name { + results, err = scanByFqname(scan, base, pkg) + if err != nil { + return nil, fmt.Errorf("DefSymbol normalized fqname lookup: %w", err) + } + if len(results) > 0 { + return results, nil + } + } + + results, err = scanByName(scan, name, pkg) + if err != nil { + return nil, fmt.Errorf("DefSymbol name lookup: %w", err) + } + if len(results) > 0 { + return results, nil + } + if base := stripInstantiationArgs(name); base != name { + results, err = scanByName(scan, base, pkg) + if err != nil { + return nil, fmt.Errorf("DefSymbol normalized name lookup: %w", err) + } + } + return results, nil +} + +const defSymbolCols = `SELECT fqname, kind, file_path, line, col, signature, package_path FROM symbols` + +type symbolScanFn func(q string, args ...any) ([]store.SymbolRow, error) + +func (s *SQLiteStore) symbolScanner(ctx context.Context) symbolScanFn { + return func(q string, args ...any) ([]store.SymbolRow, error) { rows, err := s.db.QueryContext(ctx, q, args...) if err != nil { return nil, err @@ -158,62 +196,27 @@ func (s *SQLiteStore) DefSymbol(ctx context.Context, name, pkg string) ([]store. } return results, rows.Err() } +} - // 1. Exact fqname match. - q := cols + ` WHERE fqname = ?` - qargs := []any{name} +func scanByFqname(scan symbolScanFn, name, pkg string) ([]store.SymbolRow, error) { + q := defSymbolCols + ` WHERE fqname = ?` + args := []any{name} if pkg != "" { q += ` AND package_path LIKE ?` - qargs = append(qargs, pkg+"%") - } - results, err := scanRows(q, qargs...) - if err != nil { - return nil, fmt.Errorf("DefSymbol fqname lookup: %w", err) - } - if len(results) == 0 { - baseName := stripInstantiationArgs(name) - if baseName != name { - qargs = []any{baseName} - if pkg != "" { - qargs = append(qargs, pkg+"%") - } - results, err = scanRows(q, qargs...) - if err != nil { - return nil, fmt.Errorf("DefSymbol normalized fqname lookup: %w", err) - } - } + args = append(args, pkg+"%") } + return scan(q, args...) +} - // 2. Exact name match (may be ambiguous). - if len(results) == 0 { - q = cols + ` WHERE name = ?` - qargs = []any{name} - if pkg != "" { - q += ` AND package_path LIKE ?` - qargs = append(qargs, pkg+"%") - } - q += ` ORDER BY fqname LIMIT 20` - results, err = scanRows(q, qargs...) - if err != nil { - return nil, fmt.Errorf("DefSymbol name lookup: %w", err) - } - if len(results) == 0 { - baseName := stripInstantiationArgs(name) - if baseName != name { - qargs = []any{baseName} - if pkg != "" { - qargs = append(qargs, pkg+"%") - } - qargs = append(qargs, 20) - results, err = scanRows(q, qargs...) - if err != nil { - return nil, fmt.Errorf("DefSymbol normalized name lookup: %w", err) - } - } - } +func scanByName(scan symbolScanFn, name, pkg string) ([]store.SymbolRow, error) { + q := defSymbolCols + ` WHERE name = ?` + args := []any{name} + if pkg != "" { + q += ` AND package_path LIKE ?` + args = append(args, pkg+"%") } - - return results, nil + q += ` ORDER BY fqname LIMIT 20` + return scan(q, args...) } // ---- ListPackages ----