diff --git a/internal/engine/engine.go b/internal/engine/engine.go index ce94c14..45b35f9 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -101,7 +101,7 @@ func Run(parent context.Context, cfg *Config, version string, emit func(Event)) if cfg.RecallSample > totalNeeded { totalNeeded = cfg.RecallSample } - queries, err := sampleQueryVectors(ctx, pool, env, totalNeeded) + queries, selfIDs, err := sampleQueryVectors(ctx, pool, env, totalNeeded) if err != nil { return nil, err } @@ -128,10 +128,12 @@ func Run(parent context.Context, cfg *Config, version string, emit func(Event)) // Recall (+ ef_search sweep). recallQ := queries + recallIDs := selfIDs if len(recallQ) > cfg.RecallSample { recallQ = recallQ[:cfg.RecallSample] + recallIDs = recallIDs[:cfg.RecallSample] } - rec, err := runRecall(ctx, pool, env, cfg, recallQ, emit) + rec, err := runRecall(ctx, pool, env, cfg, recallQ, recallIDs, emit) if err != nil { return nil, err } diff --git a/internal/engine/recall.go b/internal/engine/recall.go index df2f352..288c82e 100644 --- a/internal/engine/recall.go +++ b/internal/engine/recall.go @@ -26,6 +26,7 @@ func runRecall( env *Env, cfg *Config, queries [][]float32, + selfIDs []string, emit func(Event), ) (RecallResult, error) { if env.IndexType != "hnsw" && env.IndexType != "ivfflat" { @@ -44,10 +45,15 @@ func runRecall( // Determine row identity: prefer ctid (always unique, no schema assumptions). idExpr := "ctid" + // Fetch k+1: query vectors are sampled from the table, so the row itself is + // the distance-0 nearest neighbour. We drop that self-match (by ctid) from + // both the index result and the ground truth so recall@k measures the k real + // neighbours, not a guaranteed free hit. 
+ fetchK := cfg.K + 1 annSQL := fmt.Sprintf("SELECT %s::text FROM %s ORDER BY %s %s $1::vector LIMIT %d", - idExpr, tbl, col, op, cfg.K) + idExpr, tbl, col, op, fetchK) exactSQL := fmt.Sprintf("SELECT %s::text FROM %s ORDER BY %s %s $1::vector LIMIT %d", - idExpr, tbl, col, op, cfg.K) + idExpr, tbl, col, op, fetchK) // Compute ground truth once per query. groundTruth := make([][]string, len(queries)) @@ -93,7 +99,7 @@ func runRecall( conn.Release() return RecallResult{}, err } - groundTruth[i] = ids + groundTruth[i] = dropSelf(ids, selfIDs[i], cfg.K) if i%10 == 0 || i == len(queries)-1 { emit(Progress{Phase: PhaseRecall, Done: i + 1, Total: len(queries), ExtraLabel: "ground truth"}) @@ -113,7 +119,7 @@ func runRecall( if ctx.Err() != nil { return out, ctx.Err() } - pt, err := measureRecallAt(ctx, pool, env, annSQL, queries, groundTruth, ef, emit) + pt, err := measureRecallAt(ctx, pool, env, cfg, annSQL, queries, selfIDs, groundTruth, ef, emit) if err != nil { return out, err } @@ -138,8 +144,10 @@ func measureRecallAt( ctx context.Context, pool *pgxpool.Pool, env *Env, + cfg *Config, annSQL string, queries [][]float32, + selfIDs []string, groundTruth [][]string, efSearch int, emit func(Event), @@ -150,8 +158,17 @@ func measureRecallAt( } defer conn.Release() + // Wrap the whole level in one transaction and use SET LOCAL so ef_search is + // honoured even when the target is reached through a transaction pooler + // (e.g. PgBouncer/Supabase pooler), where a session-level SET on autocommit + // queries can land on a different backend and be silently ignored. 
+ tx, err := conn.Begin(ctx) + if err != nil { + return RecallPoint{}, scrub(err) + } + defer tx.Rollback(ctx) //nolint:errcheck // read-only; rollback is the cleanup if efSearch > 0 && env.IndexType == "hnsw" { - if _, err := conn.Exec(ctx, fmt.Sprintf("SET hnsw.ef_search = %d", efSearch)); err != nil { + if _, err := tx.Exec(ctx, fmt.Sprintf("SET LOCAL hnsw.ef_search = %d", efSearch)); err != nil { return RecallPoint{}, scrub(err) } } @@ -164,20 +181,17 @@ func measureRecallAt( return RecallPoint{}, ctx.Err() } t0 := time.Now() - ids, err := scanIDsConn(ctx, conn, annSQL, q) + ids, err := scanIDs(ctx, tx, annSQL, q) if err != nil { return RecallPoint{}, err } durs = append(durs, float64(time.Since(t0).Microseconds())/1000.0) - hits += jaccardIntersection(ids, groundTruth[i]) + hits += jaccardIntersection(dropSelf(ids, selfIDs[i], cfg.K), groundTruth[i]) if i%10 == 0 || i == len(queries)-1 { emit(Progress{Phase: PhaseRecall, Done: i + 1, Total: len(queries), ExtraLabel: fmt.Sprintf("ef_search=%d", efSearch)}) } } - if env.IndexType == "hnsw" { - _, _ = conn.Exec(ctx, "RESET hnsw.ef_search") - } wall := time.Since(t0Wall).Seconds() sort.Float64s(durs) return RecallPoint{ @@ -188,6 +202,24 @@ func measureRecallAt( }, nil } +// dropSelf returns a fresh slice of ids with the first entry equal to self removed (the query row's own ctid, i.e. its distance-0 self-match) +// and the remainder cut to at most k entries, so recall is scored on real neighbours only; if self is absent, ids is simply cut to k. NOTE(review): a negative k would panic in out[:k] — presumably cfg.K is validated upstream; confirm. +func dropSelf(ids []string, self string, k int) []string { + out := make([]string, 0, len(ids)) + removed := false + for _, id := range ids { + if !removed && id == self { + removed = true + continue + } + out = append(out, id) + } + if len(out) > k { + out = out[:k] + } + return out +} + // jaccardIntersection returns |a ∩ b| / |b| — the per-query recall ratio. 
func jaccardIntersection(a, b []string) float64 { set := make(map[string]struct{}, len(b)) @@ -227,10 +259,6 @@ func scanIDs(ctx context.Context, c conniface, sql string, v []float32) ([]strin return ids, rows.Err() } -func scanIDsConn(ctx context.Context, conn *pgxpool.Conn, sql string, v []float32) ([]string, error) { - return scanIDs(ctx, conn.Conn(), sql, v) -} - func verifySeqScan(ctx context.Context, tx pgx.Tx, sql string, v []float32) (bool, string) { rows, err := tx.Query(ctx, "EXPLAIN "+sql, VectorLiteral(v)) if err != nil { diff --git a/internal/engine/workload.go b/internal/engine/workload.go index 62e59a0..d5749fc 100644 --- a/internal/engine/workload.go +++ b/internal/engine/workload.go @@ -11,7 +11,13 @@ import ( // using TABLESAMPLE if the table is large, otherwise ORDER BY random(). // The sampled vectors are used both as query points and (for recall sampling) // as the basis for exact-KNN ground truth. -func sampleQueryVectors(ctx context.Context, pool *pgxpool.Pool, env *Env, n int) ([][]float32, error) { +// +// It also returns each vector's source ctid (selfIDs, aligned with the vectors) +// so recall can exclude the trivial self-match: a query vector taken straight +// from the table has itself as its own nearest neighbour (distance 0), which +// would otherwise be a guaranteed "free" hit in both the index and the ground +// truth and inflate recall@k. 
+func sampleQueryVectors(ctx context.Context, pool *pgxpool.Pool, env *Env, n int) (vecs [][]float32, selfIDs []string, err error) { tbl := quoteIdent(env.Schema, env.Table) col := quoteIdent(env.Column) @@ -27,31 +33,33 @@ func sampleQueryVectors(ctx context.Context, pool *pgxpool.Pool, env *Env, n int pct = 100 } q = fmt.Sprintf( - "SELECT %s::text FROM %s TABLESAMPLE SYSTEM (%.4f) WHERE %s IS NOT NULL LIMIT %d", + "SELECT ctid::text, %s::text FROM %s TABLESAMPLE SYSTEM (%.4f) WHERE %s IS NOT NULL LIMIT %d", col, tbl, pct, col, n) } else { q = fmt.Sprintf( - "SELECT %s::text FROM %s WHERE %s IS NOT NULL ORDER BY random() LIMIT %d", + "SELECT ctid::text, %s::text FROM %s WHERE %s IS NOT NULL ORDER BY random() LIMIT %d", col, tbl, col, n) } rows, err := pool.Query(ctx, q) if err != nil { - return nil, fmt.Errorf("sample vectors: %w", scrub(err)) + return nil, nil, fmt.Errorf("sample vectors: %w", scrub(err)) } defer rows.Close() - out := make([][]float32, 0, n) + vecs = make([][]float32, 0, n) + selfIDs = make([]string, 0, n) for rows.Next() { - var s string - if err := rows.Scan(&s); err != nil { - return nil, scrub(err) + var id, s string + if err := rows.Scan(&id, &s); err != nil { + return nil, nil, scrub(err) } v, err := ParseVectorLiteral(s) if err != nil { - return nil, err + return nil, nil, err } - out = append(out, v) + vecs = append(vecs, v) + selfIDs = append(selfIDs, id) } - return out, rows.Err() + return vecs, selfIDs, rows.Err() }