tokenizer.py
"""Character-level and BPE tokenizers built from scratch."""
import json
import re
from collections import Counter


class CharTokenizer:
    """Maps each unique character to an integer."""

    def __init__(self):
        self.char_to_id: dict[str, int] = {}
        self.id_to_char: dict[int, str] = {}

    def train(self, text: str):
        chars = sorted(set(text))
        self.char_to_id = {ch: i for i, ch in enumerate(chars)}
        self.id_to_char = {i: ch for ch, i in self.char_to_id.items()}

    @property
    def vocab_size(self) -> int:
        return len(self.char_to_id)

    def encode(self, text: str) -> list[int]:
        return [self.char_to_id[ch] for ch in text]

    def decode(self, ids: list[int]) -> str:
        return "".join(self.id_to_char[i] for i in ids)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump({"char_to_id": self.char_to_id}, f)

    def load(self, path: str):
        with open(path, "r") as f:
            data = json.load(f)
        self.char_to_id = data["char_to_id"]
        self.id_to_char = {int(i): ch for ch, i in self.char_to_id.items()}
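

# Minimal usage sketch for CharTokenizer (the sample text is illustrative,
# not part of this file). Note that encode() raises KeyError for characters
# that did not appear in the training text:
#
#     tok = CharTokenizer()
#     tok.train("hello world")
#     ids = tok.encode("hello")   # [3, 2, 4, 4, 5] for this training text
#     assert tok.decode(ids) == "hello"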


class BPETokenizer:
    """Byte-pair encoding tokenizer trained from scratch.

    Starts with individual characters as the vocabulary, then iteratively
    merges the most frequent adjacent pair into a new token.
    """

    def __init__(self):
        self.merges: list[tuple[str, str]] = []  # ordered merge rules
        self.vocab: dict[str, int] = {}  # token -> id
        self.id_to_token: dict[int, str] = {}

    def train(self, text: str, vocab_size: int = 512, verbose: bool = True):
        """Train BPE merges until we reach vocab_size."""
        # Start with character-level tokens
        chars = sorted(set(text))
        self.vocab = {ch: i for i, ch in enumerate(chars)}
        self.id_to_token = {i: ch for ch, i in self.vocab.items()}

        # Pre-tokenize: split into words (keep whitespace attached)
        words = re.findall(r"\S+|\s+", text)

        # Represent each word as a tuple of characters
        word_freqs: dict[tuple[str, ...], int] = Counter(
            tuple(w) for w in words
        )

        num_merges = vocab_size - len(chars)
        if verbose:
            print(f"Base vocab: {len(chars)} chars, planning {num_merges} merges")

        for i in range(num_merges):
            # Count all adjacent pairs
            pair_counts: Counter[tuple[str, str]] = Counter()
            for word, freq in word_freqs.items():
                for j in range(len(word) - 1):
                    pair_counts[(word[j], word[j + 1])] += freq
            if not pair_counts:
                break

            # Most frequent pair
            pair, count = pair_counts.most_common(1)[0]
            if count < 2:
                break

            # Create merged token
            merged = pair[0] + pair[1]
            new_id = len(self.vocab)
            self.vocab[merged] = new_id
            self.id_to_token[new_id] = merged
            self.merges.append(pair)

            if verbose and (i + 1) % 50 == 0:
                print(
                    f" merge {i+1}/{num_merges}: "
                    f"'{pair[0]}' + '{pair[1]}' -> '{merged}' (count={count})"
                )

            # Apply merge to all words
            new_word_freqs: dict[tuple[str, ...], int] = {}
            for word, freq in word_freqs.items():
                new_word = self._apply_merge(word, pair, merged)
                new_word_freqs[new_word] = (
                    new_word_freqs.get(new_word, 0) + freq
                )
            word_freqs = new_word_freqs

        if verbose:
            print(f"Final vocab size: {len(self.vocab)}")

    @staticmethod
    def _apply_merge(
        word: tuple[str, ...], pair: tuple[str, str], merged: str
    ) -> tuple[str, ...]:
        """Replace all occurrences of pair in word with merged token."""
        new_word: list[str] = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == pair[0] and word[i + 1] == pair[1]:
                new_word.append(merged)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        return tuple(new_word)
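
    # Worked example (illustrative values, not from the original file):
    #   _apply_merge(("h", "e", "l", "l", "o"), ("l", "l"), "ll")
    #   returns ("h", "e", "ll", "o"). Occurrences are merged left to right,
    #   so ("a", "a", "a") with pair ("a", "a") becomes ("aa", "a").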

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)

    def encode(self, text: str) -> list[int]:
        """Encode text by applying learned merges in order."""
        # Start as characters
        tokens = list(text)
        # Apply each merge rule in order
        for pair in self.merges:
            merged = pair[0] + pair[1]
            new_tokens: list[str] = []
            i = 0
            while i < len(tokens):
                if (
                    i < len(tokens) - 1
                    and tokens[i] == pair[0]
                    and tokens[i + 1] == pair[1]
                ):
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens
        return [self.vocab[t] for t in tokens]

    def decode(self, ids: list[int]) -> str:
        return "".join(self.id_to_token[i] for i in ids)

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(
                {"merges": self.merges, "vocab": self.vocab},
                f,
                indent=2,
            )

    def load(self, path: str):
        with open(path, "r") as f:
            data = json.load(f)
        self.merges = [tuple(m) for m in data["merges"]]
        self.vocab = data["vocab"]
        self.id_to_token = {int(i): t for t, i in self.vocab.items()}
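

# A minimal smoke test, added as a sketch: the corpus string and vocab_size
# below are illustrative choices, not part of the original file.
if __name__ == "__main__":
    corpus = "the quick brown fox jumps over the lazy dog. " * 20

    char_tok = CharTokenizer()
    char_tok.train(corpus)
    assert char_tok.decode(char_tok.encode(corpus)) == corpus
    print(f"CharTokenizer round-trip ok, vocab_size={char_tok.vocab_size}")

    bpe = BPETokenizer()
    bpe.train(corpus, vocab_size=64, verbose=False)
    sample = "the quick brown fox"
    ids = bpe.encode(sample)
    assert bpe.decode(ids) == sample
    print(f"BPETokenizer round-trip ok, vocab_size={bpe.vocab_size}, "
          f"{len(sample)} chars -> {len(ids)} tokens")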