Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions internal/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func Run(parent context.Context, cfg *Config, version string, emit func(Event))
if cfg.RecallSample > totalNeeded {
totalNeeded = cfg.RecallSample
}
queries, err := sampleQueryVectors(ctx, pool, env, totalNeeded)
queries, selfIDs, err := sampleQueryVectors(ctx, pool, env, totalNeeded)
if err != nil {
return nil, err
}
Expand All @@ -128,10 +128,12 @@ func Run(parent context.Context, cfg *Config, version string, emit func(Event))

// Recall (+ ef_search sweep).
recallQ := queries
recallIDs := selfIDs
if len(recallQ) > cfg.RecallSample {
recallQ = recallQ[:cfg.RecallSample]
recallIDs = recallIDs[:cfg.RecallSample]
}
rec, err := runRecall(ctx, pool, env, cfg, recallQ, emit)
rec, err := runRecall(ctx, pool, env, cfg, recallQ, recallIDs, emit)
if err != nil {
return nil, err
}
Expand Down
56 changes: 42 additions & 14 deletions internal/engine/recall.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func runRecall(
env *Env,
cfg *Config,
queries [][]float32,
selfIDs []string,
emit func(Event),
) (RecallResult, error) {
if env.IndexType != "hnsw" && env.IndexType != "ivfflat" {
Expand All @@ -44,10 +45,15 @@ func runRecall(

// Determine row identity: prefer ctid (always unique, no schema assumptions).
idExpr := "ctid"
// Fetch k+1: query vectors are sampled from the table, so the row itself is
// the distance-0 nearest neighbour. We drop that self-match (by ctid) from
// both the index result and the ground truth so recall@k measures the k real
// neighbours, not a guaranteed free hit.
fetchK := cfg.K + 1
annSQL := fmt.Sprintf("SELECT %s::text FROM %s ORDER BY %s %s $1::vector LIMIT %d",
idExpr, tbl, col, op, cfg.K)
idExpr, tbl, col, op, fetchK)
exactSQL := fmt.Sprintf("SELECT %s::text FROM %s ORDER BY %s %s $1::vector LIMIT %d",
idExpr, tbl, col, op, cfg.K)
idExpr, tbl, col, op, fetchK)

// Compute ground truth once per query.
groundTruth := make([][]string, len(queries))
Expand Down Expand Up @@ -93,7 +99,7 @@ func runRecall(
conn.Release()
return RecallResult{}, err
}
groundTruth[i] = ids
groundTruth[i] = dropSelf(ids, selfIDs[i], cfg.K)
if i%10 == 0 || i == len(queries)-1 {
emit(Progress{Phase: PhaseRecall, Done: i + 1, Total: len(queries),
ExtraLabel: "ground truth"})
Expand All @@ -113,7 +119,7 @@ func runRecall(
if ctx.Err() != nil {
return out, ctx.Err()
}
pt, err := measureRecallAt(ctx, pool, env, annSQL, queries, groundTruth, ef, emit)
pt, err := measureRecallAt(ctx, pool, env, cfg, annSQL, queries, selfIDs, groundTruth, ef, emit)
if err != nil {
return out, err
}
Expand All @@ -138,8 +144,10 @@ func measureRecallAt(
ctx context.Context,
pool *pgxpool.Pool,
env *Env,
cfg *Config,
annSQL string,
queries [][]float32,
selfIDs []string,
groundTruth [][]string,
efSearch int,
emit func(Event),
Expand All @@ -150,8 +158,17 @@ func measureRecallAt(
}
defer conn.Release()

// Wrap the whole level in one transaction and use SET LOCAL so ef_search is
// honoured even when the target is reached through a transaction pooler
// (e.g. PgBouncer/Supabase pooler), where a session-level SET on autocommit
// queries can land on a different backend and be silently ignored.
tx, err := conn.Begin(ctx)
if err != nil {
return RecallPoint{}, scrub(err)
}
defer tx.Rollback(ctx) //nolint:errcheck — read-only; rollback is the cleanup
if efSearch > 0 && env.IndexType == "hnsw" {
if _, err := conn.Exec(ctx, fmt.Sprintf("SET hnsw.ef_search = %d", efSearch)); err != nil {
if _, err := tx.Exec(ctx, fmt.Sprintf("SET LOCAL hnsw.ef_search = %d", efSearch)); err != nil {
return RecallPoint{}, scrub(err)
}
}
Expand All @@ -164,20 +181,17 @@ func measureRecallAt(
return RecallPoint{}, ctx.Err()
}
t0 := time.Now()
ids, err := scanIDsConn(ctx, conn, annSQL, q)
ids, err := scanIDs(ctx, tx, annSQL, q)
if err != nil {
return RecallPoint{}, err
}
durs = append(durs, float64(time.Since(t0).Microseconds())/1000.0)
hits += jaccardIntersection(ids, groundTruth[i])
hits += jaccardIntersection(dropSelf(ids, selfIDs[i], cfg.K), groundTruth[i])
if i%10 == 0 || i == len(queries)-1 {
emit(Progress{Phase: PhaseRecall, Done: i + 1, Total: len(queries),
ExtraLabel: fmt.Sprintf("ef_search=%d", efSearch)})
}
}
if env.IndexType == "hnsw" {
_, _ = conn.Exec(ctx, "RESET hnsw.ef_search")
}
wall := time.Since(t0Wall).Seconds()
sort.Float64s(durs)
return RecallPoint{
Expand All @@ -188,6 +202,24 @@ func measureRecallAt(
}, nil
}

// dropSelf removes the query row's own ctid (the distance-0 self-match) from a
// neighbour list and trims the result to at most k entries, so recall measures
// the k real neighbours.
//
// Only the first occurrence of self is dropped; any later duplicates are kept.
// ids is never modified: the result is always a freshly allocated, non-nil
// slice. A non-positive k yields an empty result rather than panicking on a
// negative slice bound.
func dropSelf(ids []string, self string, k int) []string {
	// Clamp so a misconfigured (negative) k cannot reach a slice expression.
	if k < 0 {
		k = 0
	}
	// The result can never hold more than k entries, nor more than len(ids).
	limit := len(ids)
	if k < limit {
		limit = k
	}
	out := make([]string, 0, limit)
	removed := false
	for _, id := range ids {
		// Stop as soon as k neighbours are collected; the tail would be
		// trimmed away anyway.
		if len(out) == k {
			break
		}
		if !removed && id == self {
			removed = true
			continue
		}
		out = append(out, id)
	}
	return out
}

// jaccardIntersection returns |a ∩ b| / |b| — the per-query recall ratio.
func jaccardIntersection(a, b []string) float64 {
set := make(map[string]struct{}, len(b))
Expand Down Expand Up @@ -227,10 +259,6 @@ func scanIDs(ctx context.Context, c conniface, sql string, v []float32) ([]strin
return ids, rows.Err()
}

// scanIDsConn is a convenience wrapper around scanIDs for a pooled
// connection: it unwraps conn via its Conn method and forwards the call with
// sql and v unchanged, returning whatever scanIDs returns.
func scanIDsConn(ctx context.Context, conn *pgxpool.Conn, sql string, v []float32) ([]string, error) {
	raw := conn.Conn()
	ids, err := scanIDs(ctx, raw, sql, v)
	return ids, err
}

func verifySeqScan(ctx context.Context, tx pgx.Tx, sql string, v []float32) (bool, string) {
rows, err := tx.Query(ctx, "EXPLAIN "+sql, VectorLiteral(v))
if err != nil {
Expand Down
30 changes: 19 additions & 11 deletions internal/engine/workload.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ import (
// using TABLESAMPLE if the table is large, otherwise ORDER BY random().
// The sampled vectors are used both as query points and (for recall sampling)
// as the basis for exact-KNN ground truth.
func sampleQueryVectors(ctx context.Context, pool *pgxpool.Pool, env *Env, n int) ([][]float32, error) {
//
// It also returns each vector's source ctid (selfIDs, aligned with the vectors)
// so recall can exclude the trivial self-match: a query vector taken straight
// from the table has itself as its own nearest neighbour (distance 0), which
// would otherwise be a guaranteed "free" hit in both the index and the ground
// truth and inflate recall@k.
func sampleQueryVectors(ctx context.Context, pool *pgxpool.Pool, env *Env, n int) (vecs [][]float32, selfIDs []string, err error) {
tbl := quoteIdent(env.Schema, env.Table)
col := quoteIdent(env.Column)

Expand All @@ -27,31 +33,33 @@ func sampleQueryVectors(ctx context.Context, pool *pgxpool.Pool, env *Env, n int
pct = 100
}
q = fmt.Sprintf(
"SELECT %s::text FROM %s TABLESAMPLE SYSTEM (%.4f) WHERE %s IS NOT NULL LIMIT %d",
"SELECT ctid::text, %s::text FROM %s TABLESAMPLE SYSTEM (%.4f) WHERE %s IS NOT NULL LIMIT %d",
col, tbl, pct, col, n)
} else {
q = fmt.Sprintf(
"SELECT %s::text FROM %s WHERE %s IS NOT NULL ORDER BY random() LIMIT %d",
"SELECT ctid::text, %s::text FROM %s WHERE %s IS NOT NULL ORDER BY random() LIMIT %d",
col, tbl, col, n)
}

rows, err := pool.Query(ctx, q)
if err != nil {
return nil, fmt.Errorf("sample vectors: %w", scrub(err))
return nil, nil, fmt.Errorf("sample vectors: %w", scrub(err))
}
defer rows.Close()

out := make([][]float32, 0, n)
vecs = make([][]float32, 0, n)
selfIDs = make([]string, 0, n)
for rows.Next() {
var s string
if err := rows.Scan(&s); err != nil {
return nil, scrub(err)
var id, s string
if err := rows.Scan(&id, &s); err != nil {
return nil, nil, scrub(err)
}
v, err := ParseVectorLiteral(s)
if err != nil {
return nil, err
return nil, nil, err
}
out = append(out, v)
vecs = append(vecs, v)
selfIDs = append(selfIDs, id)
}
return out, rows.Err()
return vecs, selfIDs, rows.Err()
}
Loading