Diversity Sampler (Farthest-Point / Max-Min via FAISS) #16

@Gautam-Rajeev

Goal

Select the most diverse subset from raw_samples.jsonl using max-min farthest-point sampling over sentence embeddings indexed in FAISS. The algorithm greedily picks the next sample that is as far as possible from all already-selected samples, ensuring the chosen set spans the full semantic space of the corpus rather than clustering around frequent topics.

Hard per-dimension quotas (topic, lang, query_type) act as a secondary constraint: the farthest-point loop runs within each quota bucket, so diversity is enforced both semantically and categorically.


Algorithm: Max-Min Farthest-Point Sampling

Given a set of points X with embeddings {e_1 … e_n} and a target size k:

  1. Pick a random seed point s_0 as the initial selected set S = {s_0}.
  2. For each remaining point x_i, compute d(x_i, S) = min_{s ∈ S} dist(x_i, s) — the distance to the nearest already-selected point.
  3. Select x* = argmax_i d(x_i, S) — the point farthest from everything already chosen.
  4. Add x* to S, then update each d(x_i, S) by comparing it with the distance to x* only; distances to earlier selections do not need to be recomputed.
  5. Repeat until |S| = k.

This produces a well-spread sample: every pick maximises its distance to the nearest already-selected point, so near-duplicates of selected samples are the last candidates to be chosen and underrepresented regions of the semantic space are pulled in early.
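The five steps above can be sketched in NumPy as follows. This is a minimal illustration, not the final implementation: it assumes the rows of `embeddings` are already L2-normalised (so cosine distance is 1 minus the inner product), it takes the seed index as an argument instead of drawing it at random, and the early-stop argument mirrors `min_distance_threshold` from the config.

```python
import numpy as np

def farthest_point_sample(embeddings, k, seed_index=0, min_distance_threshold=0.0):
    """Greedy max-min selection; returns row indices in the order they were picked.

    Assumes each row of `embeddings` is L2-normalised.
    """
    n = embeddings.shape[0]
    if n == 0 or k <= 0:
        return []
    selected = [seed_index]
    # min_dist[i] = cosine distance from point i to its nearest selected point
    min_dist = 1.0 - embeddings @ embeddings[seed_index]
    min_dist[seed_index] = -np.inf  # never pick a selected point again
    while len(selected) < min(k, n):
        nxt = int(np.argmax(min_dist))
        if min_dist[nxt] < min_distance_threshold:
            break  # every remaining point is too close to the selected set
        selected.append(nxt)
        # Step 4: only distances to the newly added point are computed
        np.minimum(min_dist, 1.0 - embeddings @ embeddings[nxt], out=min_dist)
        min_dist[nxt] = -np.inf
    return selected

# Four unit vectors at 0, 10, 90 and 180 degrees
angles = np.deg2rad([0.0, 10.0, 90.0, 180.0])
points = np.stack([np.cos(angles), np.sin(angles)], axis=1)
print(farthest_point_sample(points, k=3))  # [0, 3, 2]
```

Starting from the 0-degree vector, the loop picks the 180-degree vector first, then the 90-degree one; the 10-degree vector, which nearly duplicates the seed, is left for last.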

Why FAISS:

Any vector-search library would work here. I'm picking FAISS because I'm familiar with it, it offers enough index types, and it is fast.


Inputs

| Input | Description |
| --- | --- |
| raw_samples.jsonl | Output of Ticket 1. Each record has prompt, response, topic, lang, query_type. |
| sampler_config.yaml | Quotas per dimension, target sample size, and embedding model ID. |

Deliverables

diversity_sampler.py

Main class. Exposes:

DiversitySampler(config_path: str)
    .embed(samples: List[dict]) -> np.ndarray          # encode all prompts+responses
    .build_index(embeddings: np.ndarray) -> faiss.Index
    .farthest_point_sample(
          embeddings, indices, k
      ) -> List[int]                                    # core max-min loop
    .sample(samples: List[dict]) -> List[dict]          # full pipeline incl. quota bucketing

sampler_config.yaml

embedding_model: "paraphrase-multilingual-MiniLM-L12-v2"
target_total: 10000          # total samples after sampling

quotas:
  topic:
    crop_advisory:   0.25
    market_price:    0.20
    gov_schemes:     0.15
    weather:         0.20
    pest_disease:    0.20

  lang:
    en:              0.25
    hi:              0.45
    hi_en_mix:       0.30

  query_type:
    simple_qa:            0.40
    tool_call_single:     0.30
    tool_call_multi_hop:  0.20
    adversarial:          0.10

farthest_point:
  seed_strategy: "random"    # "random" | "highest_score" | "centroid_farthest"
  distance_metric: "cosine"  # vectors L2-normalised; inner product == cosine sim
  min_distance_threshold: 0.05   # stop a cell early once its farthest candidate is closer than this (corpus exhausted)
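Since the quotas are later multiplied together, it may be worth checking at load time that each dimension's proportions sum to 1. A small sketch of such a check, operating on the already-parsed `quotas` mapping (the function name `check_quotas` is hypothetical, and YAML parsing is left out):

```python
import math

def check_quotas(quotas, tol=1e-9):
    """Return the names of dimensions whose proportions do not sum to 1."""
    return [
        dim
        for dim, proportions in quotas.items()
        if not math.isclose(sum(proportions.values()), 1.0, abs_tol=tol)
    ]

quotas = {
    "topic": {"crop_advisory": 0.25, "market_price": 0.20, "gov_schemes": 0.15,
              "weather": 0.20, "pest_disease": 0.20},
    "lang": {"en": 0.25, "hi": 0.45, "hi_en_mix": 0.30},
    "query_type": {"simple_qa": 0.40, "tool_call_single": 0.30,
                   "tool_call_multi_hop": 0.20, "adversarial": 0.10},
}
bad_dimensions = check_quotas(quotas)  # empty list when every dimension sums to 1
```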

run_sampler.py

CLI wrapper:

python run_sampler.py \
  --input  data/raw_samples.jsonl \
  --output data/sampled_data.jsonl \
  --config sampler_config.yaml \
  --log    logs/sampler_run.json
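A possible argument parser for the command above, using the standard-library `argparse`. Which flags are required, and the help strings, are assumptions; only the four flag names come from this ticket.

```python
import argparse

def build_parser():
    """Argument parser matching the four flags of run_sampler.py."""
    parser = argparse.ArgumentParser(description="Run the diversity sampler.")
    parser.add_argument("--input", required=True, help="path to raw_samples.jsonl")
    parser.add_argument("--output", required=True, help="path for the sampled JSONL file")
    parser.add_argument("--config", required=True, help="path to sampler_config.yaml")
    # Assumption: the run log is optional
    parser.add_argument("--log", default=None, help="path for the JSON run log")
    return parser

args = build_parser().parse_args([
    "--input", "data/raw_samples.jsonl",
    "--output", "data/sampled_data.jsonl",
    "--config", "sampler_config.yaml",
])
```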

Implementation Notes

Bucketing Strategy

Do not run one global farthest-point pass and then apply quotas as a post-filter — that risks quota buckets being underfilled because the global pass never reaches underrepresented areas.

Instead:

  1. Split raw_samples into cells defined by (topic, lang, query_type).
  2. For each cell, compute how many samples it should contribute: cell_target = target_total * p(topic) * p(lang) * p(query_type) (joint quota, assuming independence). Round to int; redistribute remainders to largest cells.
  3. Run the farthest-point loop independently within each cell for cell_target samples.
  4. Merge all selected samples and shuffle before writing.

This way the semantic diversity guarantee holds within each category, not just globally.
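Step 2 can be sketched as below. Two assumptions are baked in: each dimension's quotas sum to 1, and "redistribute remainders to largest cells" is read as "give the leftover units, one each, to the cells with the largest exact targets". `Fraction` is used so the products are exact and the flooring is not affected by floating-point error. The function name `cell_targets` is hypothetical.

```python
import math
from fractions import Fraction
from itertools import product

def cell_targets(quotas, target_total):
    """Map each (topic, lang, query_type) cell to an integer sample count."""
    dims = list(quotas)
    exact = {}
    for combo in product(*(quotas[d].items() for d in dims)):
        key = tuple(name for name, _ in combo)
        share = Fraction(target_total)
        for _, proportion in combo:
            share *= Fraction(str(proportion))  # exact decimal, e.g. 0.45 -> 9/20
        exact[key] = share
    targets = {key: math.floor(value) for key, value in exact.items()}
    # Units lost to flooring; assumes each dimension's quotas sum to 1
    leftover = max(0, target_total - sum(targets.values()))
    for key in sorted(exact, key=exact.get, reverse=True)[:leftover]:
        targets[key] += 1
    return targets

quotas = {
    "topic": {"crop_advisory": 0.25, "market_price": 0.20, "gov_schemes": 0.15,
              "weather": 0.20, "pest_disease": 0.20},
    "lang": {"en": 0.25, "hi": 0.45, "hi_en_mix": 0.30},
    "query_type": {"simple_qa": 0.40, "tool_call_single": 0.30,
                   "tool_call_multi_hop": 0.20, "adversarial": 0.10},
}
targets = cell_targets(quotas, 10000)
print(len(targets), sum(targets.values()))  # 60 10000
```

Handing leftovers to the cells with the largest fractional remainders instead would be an equally reasonable reading; whichever rule is chosen, the counts should still sum to target_total.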

Embedding

The embedding strategy is open to revision; for now, let's use the following:

  • Encode prompt + " " + response (concatenated) so both query style and answer style are captured.
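A rough sketch of the text construction and normalisation step. The `encode` argument is a hypothetical stand-in for whatever turns a list of strings into vectors; with the sentence-transformers library it could be the `encode` method of a model loaded from `embedding_model`, but that wiring is not shown or assumed here. Normalising the rows is what makes inner product equal cosine similarity later on.

```python
import numpy as np

def build_texts(samples):
    """Concatenate prompt and response with a single space, as described above."""
    return [f"{s['prompt']} {s['response']}" for s in samples]

def embed(samples, encode):
    """Encode the texts and L2-normalise each row.

    `encode` is any callable mapping a list of strings to a 2-D array-like.
    """
    vectors = np.asarray(encode(build_texts(samples)), dtype=np.float32)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.clip(norms, 1e-12, None)  # guard against zero vectors
```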

FAISS Index per Cell

  • The index type is still to be decided. faiss.IndexFlatIP (exact inner product) is a reasonable starting point.
  • The "selected set" index is built incrementally: start empty, add one vector per iteration.

Distance Update (efficient)

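Reuse the previous iteration's distances instead of recomputing them. Keep a running array of d(x_i, S) for every candidate; after adding x* to S, the new value is d(x_i, S ∪ {x*}) = min(d(x_i, S), dist(x_i, x*)), so each iteration needs only one pass of distances against the newly added point.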

Edge Cases

| Case | Handling |
| --- | --- |
| Cell has fewer samples than cell_target | Take all samples in cell; log as skipped_exhausted_buckets |
| min_distance_threshold hit mid-loop | Stop early for that cell; log count actually selected |
| Duplicate prompts in raw data | Deduplicate by prompt hash before sampling |
| Missing topic/lang field | Assign to "unknown" bucket; do not crash |
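The last two cases can be handled in a single pass before sampling, roughly as follows. This sketch hashes the exact prompt string with SHA-256 (no whitespace or case normalisation, which is an assumption) and also sends a missing query_type to "unknown", which goes slightly beyond the topic/lang wording above. The function name is hypothetical.

```python
import hashlib

DIMENSIONS = ("topic", "lang", "query_type")

def dedupe_and_bucket(samples):
    """Drop repeated prompts, then group the rest by (topic, lang, query_type)."""
    seen = set()
    cells = {}
    for sample in samples:
        digest = hashlib.sha256(sample["prompt"].encode("utf-8")).hexdigest()
        if digest in seen:
            continue  # keep only the first record for each prompt
        seen.add(digest)
        # Missing or empty fields fall into an "unknown" bucket instead of raising
        key = tuple(sample.get(dim) or "unknown" for dim in DIMENSIONS)
        cells.setdefault(key, []).append(sample)
    return cells
```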
