Skip to content

Commit 7480c9c

Browse files
authored
perf: fall back to brute-force FTS if filters match only a few rows (lance-format#4551)
If the filters match only a few rows, the WAND algorithm may fail to filter out docs. In that case we can evaluate only the matched rows, which is much faster than running WAND first: it decompresses at most `num_rows_matched * num_tokens` blocks. --------- Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent 746d2dd commit 7480c9c

3 files changed

Lines changed: 180 additions & 21 deletions

File tree

python/python/tests/test_scalar_index.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,38 @@ def test_fts_score(tmp_path):
721721
assert results["id"].to_pylist() == [3, 2, 1]
722722

723723

724+
def test_fts_with_filter(tmp_path):
725+
data = pa.table(
726+
{
727+
"id": [1, 2, 3],
728+
"text": ["lance database test", "full text search", "lance search text"],
729+
}
730+
)
731+
ds = lance.write_dataset(data, tmp_path)
732+
ds.create_scalar_index("id", "BTREE")
733+
ds.create_scalar_index("text", "INVERTED")
734+
735+
results = ds.to_table(full_text_query="lance search text")
736+
assert results.num_rows == 3
737+
assert results["id"].to_pylist() == [3, 2, 1]
738+
739+
score_id1 = results.column("_score")[2].as_py()
740+
741+
results = ds.to_table(
742+
full_text_query="lance search text",
743+
filter="id <= 1",
744+
prefilter=True,
745+
)
746+
assert results.num_rows == 1
747+
assert results["id"].to_pylist() == [1]
748+
assert results.column("_score")[0].as_py() == score_id1
749+
750+
plan = ds.scanner(
751+
full_text_query="lance search text", filter="id <= 1", prefilter=True
752+
).analyze_plan()
753+
assert "index_comparisons=1" in plan
754+
755+
724756
def test_fts_on_list(tmp_path):
725757
data = pa.table(
726758
{

rust/lance-index/src/scalar/inverted/index.rs

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use std::sync::Arc;
66
use std::{
77
cmp::{min, Reverse},
88
collections::BinaryHeap,
9-
ops::RangeInclusive,
109
};
1110
use std::{
1211
collections::{HashMap, HashSet},
@@ -159,6 +158,7 @@ impl InvertedIndex {
159158
return Ok((Vec::new(), Vec::new()));
160159
}
161160
let mask = prefilter.mask();
161+
162162
let mut candidates = BinaryHeap::new();
163163
let parts = self
164164
.partitions
@@ -390,6 +390,7 @@ impl ScalarIndex for InvertedIndex {
390390
.buffer_unordered(store.io_parallelism())
391391
.try_collect::<Vec<_>>()
392392
.await?;
393+
393394
let tokenizer = params.build()?;
394395
Ok(Arc::new(Self {
395396
params,
@@ -1738,6 +1739,9 @@ impl Ord for RawDocInfo {
17381739
pub struct DocSet {
17391740
row_ids: Vec<u64>,
17401741
num_tokens: Vec<u32>,
1742+
// (row_id, doc_id) pairs sorted by row_id
1743+
inv: Vec<(u64, u32)>,
1744+
17411745
total_tokens: u64,
17421746
}
17431747

@@ -1759,8 +1763,19 @@ impl DocSet {
17591763
self.row_ids[doc_id as usize]
17601764
}
17611765

1762-
pub fn row_range(&self) -> RangeInclusive<u64> {
1763-
self.row_ids[0]..=self.row_ids[self.len() - 1]
1766+
pub fn doc_id(&self, row_id: u64) -> Option<u64> {
1767+
if self.inv.is_empty() {
1768+
// in legacy format, the row id is doc id
1769+
match self.row_ids.binary_search(&row_id) {
1770+
Ok(_) => Some(row_id),
1771+
Err(_) => None,
1772+
}
1773+
} else {
1774+
match self.inv.binary_search_by_key(&row_id, |x| x.0) {
1775+
Ok(idx) => Some(self.inv[idx].1 as u64),
1776+
Err(_) => None,
1777+
}
1778+
}
17641779
}
17651780

17661781
pub fn total_tokens_num(&self) -> u64 {
@@ -1829,25 +1844,30 @@ impl DocSet {
18291844
let row_id_col = batch[ROW_ID].as_primitive::<datatypes::UInt64Type>();
18301845
let num_tokens_col = batch[NUM_TOKEN_COL].as_primitive::<datatypes::UInt32Type>();
18311846

1832-
let (row_ids, num_tokens) = match is_legacy {
1847+
let (row_ids, num_tokens, inv) = match is_legacy {
18331848
// for legacy format, the row id is doc id,
18341849
// in order to support efficient search, we need to sort the row ids,
18351850
// so that we can use binary search to get num_tokens
1836-
true => row_id_col
1837-
.values()
1838-
.iter()
1839-
.filter_map(|id| {
1840-
if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
1841-
frag_reuse_index_ref.remap_row_id(*id)
1842-
} else {
1843-
Some(*id)
1844-
}
1845-
})
1846-
.zip(num_tokens_col.values().iter())
1847-
.sorted_unstable_by_key(|x| x.0)
1848-
.unzip(),
1851+
true => {
1852+
let (row_ids, num_tokens) = row_id_col
1853+
.values()
1854+
.iter()
1855+
.filter_map(|id| {
1856+
if let Some(frag_reuse_index_ref) = frag_reuse_index.as_ref() {
1857+
frag_reuse_index_ref.remap_row_id(*id)
1858+
} else {
1859+
Some(*id)
1860+
}
1861+
})
1862+
.zip(num_tokens_col.values().iter())
1863+
.sorted_unstable_by_key(|x| x.0)
1864+
.unzip();
1865+
1866+
// the legacy format doesn't need to store the inv
1867+
(row_ids, num_tokens, Vec::new())
1868+
}
18491869
false => {
1850-
let row_ids = row_id_col
1870+
let row_ids: Vec<u64> = row_id_col
18511871
.values()
18521872
.iter()
18531873
.filter_map(|id| {
@@ -1859,14 +1879,24 @@ impl DocSet {
18591879
})
18601880
.collect();
18611881
let num_tokens = num_tokens_col.values().to_vec();
1862-
(row_ids, num_tokens)
1882+
1883+
// build the inv
1884+
let inv = row_ids
1885+
.iter()
1886+
.copied()
1887+
.enumerate()
1888+
.sorted_unstable()
1889+
.map(|(i, row_id)| (row_id, i as u32))
1890+
.collect();
1891+
(row_ids, num_tokens, inv)
18631892
}
18641893
};
18651894

18661895
let total_tokens = num_tokens.iter().map(|&x| x as u64).sum();
18671896
Ok(Self {
18681897
row_ids,
18691898
num_tokens,
1899+
inv,
18701900
total_tokens,
18711901
})
18721902
}
@@ -1901,6 +1931,8 @@ impl DocSet {
19011931
self.num_tokens[doc_id as usize]
19021932
}
19031933

1934+
// this can be used only if it's a legacy format,
1935+
// which store the sorted row ids so that we can use binary search
19041936
#[inline]
19051937
pub fn num_tokens_by_row_id(&self, row_id: u64) -> u32 {
19061938
self.row_ids

rust/lance-index/src/scalar/inverted/wand.rs

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use arrow::array::AsArray;
99
use arrow::datatypes::{Int32Type, UInt32Type};
1010
use arrow_array::{Array, UInt32Array};
1111
use arrow_schema::DataType;
12+
use itertools::Itertools;
13+
use lance_core::utils::address::RowAddress;
1214
use lance_core::utils::mask::RowIdMask;
1315
use lance_core::Result;
1416

@@ -321,17 +323,26 @@ impl<'a, S: Scorer> Wand<'a, S> {
321323
return Ok(vec![]);
322324
}
323325

326+
let avg_posting_length =
327+
self.postings.iter().map(|p| p.list.len()).sum::<usize>() / self.postings.len();
328+
match (mask.max_len(), mask.iter_ids()) {
329+
(Some(num_rows_matched), Some(row_ids))
330+
if num_rows_matched <= avg_posting_length as u64 =>
331+
{
332+
return self.flat_search(params, row_ids, metrics);
333+
}
334+
_ => {}
335+
}
336+
324337
let mut candidates = BinaryHeap::new();
325338
let mut num_comparisons = 0;
326339
while let Some((pivot, doc)) = self.next()? {
327340
self.cur_doc = Some(doc);
328341
num_comparisons += 1;
329342

330-
// if the doc is not located, we need to find the row id
331343
let row_id = match &doc {
332344
DocInfo::Raw(doc) => {
333345
// if the doc is not located, we need to find the row id
334-
// in the doc set. This is a bit slow, but it should be rare.
335346
self.docs.row_id(doc.doc_id)
336347
}
337348
DocInfo::Located(doc) => doc.row_id,
@@ -379,6 +390,90 @@ impl<'a, S: Scorer> Wand<'a, S> {
379390
.collect())
380391
}
381392

393+
fn flat_search(
394+
&mut self,
395+
params: &FtsSearchParams,
396+
row_ids: Box<dyn Iterator<Item = RowAddress> + '_>,
397+
metrics: &dyn MetricsCollector,
398+
) -> Result<Vec<DocCandidate>> {
399+
let limit = params.limit.unwrap_or(usize::MAX);
400+
if limit == 0 {
401+
return Ok(vec![]);
402+
}
403+
404+
// we need to map the row ids to doc ids, and sort them,
405+
// because WAND PostingIterator can't go back to the previous doc id
406+
let doc_ids = row_ids
407+
.filter_map(|row_addr| {
408+
let row_id: u64 = row_addr.into();
409+
self.docs.doc_id(row_id).map(|doc_id| (doc_id, row_id))
410+
})
411+
.sorted_unstable()
412+
.collect::<Vec<_>>();
413+
let is_compressed = matches!(self.postings[0].list, PostingList::Compressed(_));
414+
415+
let mut num_comparisons = 0;
416+
let mut candidates = BinaryHeap::new();
417+
for (doc_id, row_id) in doc_ids {
418+
num_comparisons += 1;
419+
420+
// move all postings to this doc id
421+
self.move_preceding(self.postings.len() - 1, doc_id);
422+
if self.postings.is_empty() {
423+
// no more postings, so we can stop
424+
break;
425+
} else if self.postings[0].doc().map(|d| d.doc_id()) != Some(doc_id) {
426+
// this doc is not in the postings, so we can skip it
427+
continue;
428+
}
429+
430+
let mut pivot = 0;
431+
while pivot + 1 < self.postings.len()
432+
&& self.postings[pivot + 1].doc().map(|d| d.doc_id()) == Some(doc_id)
433+
{
434+
pivot += 1;
435+
}
436+
437+
// check positions
438+
if params.phrase_slop.is_some()
439+
&& !self.check_positions(params.phrase_slop.unwrap() as i32)
440+
{
441+
continue;
442+
}
443+
444+
// score the doc
445+
let doc_length = match is_compressed {
446+
true => self.docs.num_tokens(doc_id as u32),
447+
false => self.docs.num_tokens_by_row_id(row_id),
448+
};
449+
450+
let score = self.score(pivot, doc_length);
451+
let freqs = self
452+
.iter_token_freqs(pivot)
453+
.map(|(token, freq)| (token.to_owned(), freq))
454+
.collect();
455+
456+
if candidates.len() < limit {
457+
candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
458+
} else if score > candidates.peek().unwrap().0 .0.score.0 {
459+
candidates.pop();
460+
candidates.push(Reverse((ScoredDoc::new(row_id, score), freqs, doc_length)));
461+
self.threshold = candidates.peek().unwrap().0 .0.score.0 * params.wand_factor;
462+
}
463+
}
464+
metrics.record_comparisons(num_comparisons);
465+
466+
Ok(candidates
467+
.into_sorted_vec()
468+
.into_iter()
469+
.map(|Reverse((doc, freqs, doc_length))| DocCandidate {
470+
row_id: doc.row_id,
471+
freqs,
472+
doc_length,
473+
})
474+
.collect())
475+
}
476+
382477
// calculate the score of the current document
383478
fn score(&self, pivot: usize, doc_length: u32) -> f32 {
384479
let mut score = 0.0;

0 commit comments

Comments
 (0)