-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpatch_colbert_modeling.py
More file actions
80 lines (62 loc) · 2.91 KB
/
patch_colbert_modeling.py
File metadata and controls
80 lines (62 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Patch an installed ColBERT ``colbert/modeling/colbert.py`` so that
``ColBERT.segmented_maxsim`` has a pure-Python fallback (no compiled
extension required).

Usage::

    python patch_colbert_modeling.py [path/to/colbert.py]

If no path is given, ``DEFAULT_TARGET`` is used.
"""
import os
import re
import sys

# Default install location to patch; override via the first CLI argument.
DEFAULT_TARGET = r"C:\Users\stdag\AppData\Local\uv\cache\archive-v0\mdpp6F6wi_f8_CcbVUqKb\Lib\site-packages\colbert\modeling\colbert.py"

# The class header the injection is anchored to.
CLASS_DEF = "class ColBERT(BaseColBERT):"

# Method injected directly under the ColBERT class header.  The lines are
# indented four spaces so they land inside the class body.
# NOTE(review): assumes `scores` is [total_embeddings, query_len] and
# `lengths` is [num_docs] — confirm against colbert_score_packed's caller.
FALLBACK_CODE = '''
    @classmethod
    def segmented_maxsim(cls, scores, lengths):
        """Pure-Python fallback for the compiled segmented_maxsim kernel.

        scores:  [total_embeddings, query_len] dot products of all doc
                 embeddings against the query tokens.
        lengths: [num_docs] number of embeddings per document.
        Returns a [num_docs] tensor: per-document sum over query terms of
        the max score across that document's embeddings (MaxSim).
        """
        import torch
        results = []
        offset = 0
        lengths_cpu = lengths.cpu().long()
        for i in range(len(lengths)):
            length = lengths_cpu[i].item()
            # doc_scores: [length, query_len] slice for this document.
            doc_scores = scores[offset : offset + length]
            if length > 0:
                max_scores, _ = doc_scores.max(dim=0)
                total_score = max_scores.sum()
            else:
                # Empty document contributes a zero score.
                total_score = torch.tensor(0.0, device=scores.device)
            results.append(total_score)
            offset += length
        return torch.stack(results)
'''

# Matches a previously injected segmented_maxsim classmethod: the decorator
# line, the def line, and every following line indented deeper than the
# decorator (blank lines included).  The backreferenced indent stops the
# match at the next sibling member, unlike a lazy ".+?return .+?" which can
# terminate at the first "return" inside a comment.
_PATCH_RE = re.compile(
    r"^(?P<ind>[ \t]*)@classmethod[ \t]*\n"
    r"(?P=ind)def segmented_maxsim\([^\n]*\n"
    r"(?:(?:(?P=ind)[ \t]+[^\n]*)?\n)*",
    re.MULTILINE,
)


def strip_existing_patch(content: str) -> str:
    """Return *content* with any previously injected segmented_maxsim removed.

    A segmented_maxsim that is NOT decorated with @classmethod (e.g. an
    upstream definition) is deliberately left alone, matching the original
    script's removal scope.
    """
    return _PATCH_RE.sub("", content)


def inject_fallback(content: str) -> str:
    """Insert FALLBACK_CODE immediately after the first ColBERT class header.

    Raises ValueError when the class header is absent.  partition() (rather
    than split()) keeps the file intact even if the header text happens to
    appear more than once.
    """
    head, sep, tail = content.partition(CLASS_DEF)
    if not sep:
        raise ValueError("class ColBERT definition not found")
    return head + CLASS_DEF + "\n" + FALLBACK_CODE + tail


def main(argv=None) -> int:
    """Patch the target file in place; return a process exit code."""
    args = sys.argv[1:] if argv is None else argv
    target_path = args[0] if args else DEFAULT_TARGET
    if not os.path.exists(target_path):
        print(f"Error: File not found at {target_path}")
        return 1
    with open(target_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Re-running the script replaces the previously injected method instead
    # of stacking a duplicate.
    if "def segmented_maxsim" in content:
        print("segmented_maxsim exists. Removing old version...")
        content = strip_existing_patch(content)
    try:
        new_content = inject_fallback(content)
    except ValueError:
        print("Error: Could not find class ColBERT definition.")
        return 1
    with open(target_path, "w", encoding="utf-8") as f:
        f.write(new_content)
    print("Successfully patched colbert.py with new code")
    return 0


if __name__ == "__main__":
    sys.exit(main())