Skip to content

Commit ed932b6

Browse files
committed
Normalize document length
1 parent 89a8964 commit ed932b6

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

rust/search/search.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,10 @@ pub fn colbert_score_reduce(token_scores: &Tensor, attention_mask: &Tensor) -> T
267267
// Padded doc tokens were set to -9999.0 above; zero them out before aggregation.
268268
let (max_scores_d_to_q, _) = masked_scores.max_dim(2, false);
269269
let valid_doc_mask = attention_mask.to_kind(Kind::Float);
270-
let d_to_q = (max_scores_d_to_q * valid_doc_mask).sum_dim_intlist(-1, false, Kind::Float);
270+
let doc_lengths = valid_doc_mask.sum_dim_intlist(-1, false, Kind::Float).clamp_min(1.0);
271+
let d_to_q = (max_scores_d_to_q * &valid_doc_mask)
272+
.sum_dim_intlist(-1, false, Kind::Float)
273+
/ doc_lengths;
271274

272275
// Average the two directions to obtain a bidirectional MaxSim score.
273276
(q_to_d + d_to_q) / 2.0

0 commit comments

Comments
 (0)