-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathpreprocess.cpp
More file actions
228 lines (213 loc) · 5.95 KB
/
preprocess.cpp
File metadata and controls
228 lines (213 loc) · 5.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#include <vector>
#include <string>
#include <iostream>
#include <cassert>
#include <iomanip>
#include <set>
#include <map>
#include <cmath>
#include <fstream>
#include <algorithm>
#include <queue>
#include <sstream>
// Shorthand macros and type aliases available to the rest of this file.
// All std:: names are visible unqualified from here on.
using namespace std;
// rep(i, a, b): loop with int i = a, a+1, ..., b-1.
#define rep(i, a, b) for(int i = (a); i < int(b); ++i)
// rrep(i, a, b): loop counting down with int i = a-1, a-2, ..., b.
#define rrep(i, a, b) for(int i = (a) - 1; i >= int(b); --i)
// trav(x, v): range-for over v, binding each element by reference as x.
#define trav(x, v) for(auto& x : v)
// sz(x): size of a container as a signed int.
#define sz(x) (int)(x).size()
// all(v): expands to the iterator pair `v.begin(), v.end()` for algorithm calls.
#define all(v) (v).begin(), (v).end()
// what_is(x): debug print "x is <value>" to stdout. The expansion already
// ends in ';', so `if (c) what_is(x); else ...` will not compile.
#define what_is(x) cout << #x << " is " << x << endl;
// Short type aliases.
typedef double fl;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef vector<pii> vpi;
// Quickly parse a decimal floating-point token such as "42", "-1.25", ".5",
// "+3" or "6.02e+23".
//
// Accepted grammar:
//   [sign] digits-with-at-most-one-'.' [('e' | 'E') [sign] digits]
// where sign is '+' or '-' and at least one mantissa digit is required.
// The token must already be isolated (no surrounding whitespace).
// Malformed input trips an assert; with NDEBUG defined it yields an
// unspecified value instead, so only use this on trusted input files.
//
// All mantissa digits are accumulated as one whole number, and the decimal
// point and exponent are applied once at the end as a single power-of-ten
// scale. Accumulating 0.1, 0.01, ... digit by digit instead compounds
// rounding error (it turns "0.3" into 0.30000000000000004).
double parseDouble(const string& str) {
    double mantissa = 0;     // significant digits read as an integer
    int fracDigits = 0;      // digits folded into mantissa after the '.'
    int droppedDigits = 0;   // integer-part digits beyond double precision
    int exponent = 0;
    bool neg = false, seenSign = false, seenDot = false, seenDigit = false;
    bool inExp = false, expNeg = false, seenExpSign = false, seenExpDigit = false;
    for (char c : str) {
        if (inExp) {
            if (c == '-' || c == '+') {
                assert(!seenExpSign && !seenExpDigit);
                seenExpSign = true;
                expNeg = (c == '-');
            } else {
                int dig = c - '0';
                assert(0 <= dig && dig < 10);
                seenExpDigit = true;
                // Saturate: such magnitudes over/underflow anyway, and the cap
                // keeps the int arithmetic from overflowing.
                if (exponent < 100000)
                    exponent = exponent * 10 + dig;
            }
        } else if (c == '-' || c == '+') {
            assert(!seenSign && !seenDigit && !seenDot);
            seenSign = true;
            neg = (c == '-');
        } else if (c == '.') {
            assert(!seenDot);
            seenDot = true;
        } else if (c == 'e' || c == 'E') {
            inExp = true;
        } else {
            int dig = c - '0';
            assert(0 <= dig && dig < 10);
            seenDigit = true;
            if (mantissa < 1e18) {
                mantissa = mantissa * 10 + dig;
                if (seenDot) ++fracDigits;
            } else if (!seenDot) {
                // A double cannot hold this much precision; drop the digit but
                // keep its place value so the magnitude stays right.
                ++droppedDigits;
            }
        }
    }
    assert(seenDigit);
    assert(!inExp || seenExpDigit);
    int scale = (expNeg ? -exponent : exponent) - fracDigits + droppedDigits;
    double result;
    if (scale >= 0 || scale < -308) {
        result = mantissa * pow(10.0, scale);
    } else {
        // 10^-scale is finite here. For ordinary tokens both operands are
        // exactly representable, so the quotient is correctly rounded.
        result = mantissa / pow(10.0, -scale);
    }
    return neg ? -result : result;
}
// Return a copy of s with ASCII 'A'..'Z' mapped to 'a'..'z'.
// Every other byte, including non-ASCII bytes, is copied unchanged.
string tolower(string s){
    for (size_t i = 0; i < s.size(); ++i) {
        char& ch = s[i];
        if ('A' <= ch && ch <= 'Z')
            ch = char(ch - 'A' + 'a');
    }
    return s;
}
// Return s with its first character upper-cased if it is an ASCII lower-case
// letter. The rest of the string, and an empty string, are left untouched.
string capitalize(string s){
    if (!s.empty()) {
        char& first = s.front();
        if ('a' <= first && first <= 'z')
            first = char(first - ('a' - 'A'));
    }
    return s;
}
// Remove trailing spaces, tabs and carriage returns from s in place, so that
// CRLF files and lines ending in a separator do not yield bogus tokens.
static void trimTrailingWhitespace(string& s) {
    size_t last = s.find_last_not_of(" \t\r");
    s.erase(last == string::npos ? 0 : last + 1);
}

// Select the words to ship, look up their embeddings, and write them to a
// compact binary file.
//
// Parameters:
//   inFile       word2vec text file, one "word a_1 ... a_k" entry per line
//                (components separated by spaces/tabs). Lines whose word is
//                not selected are skipped without parsing the numbers.
//   popFile      words in decreasing order of popularity; only the first
//                whitespace-separated token of each line is used.
//   outFile      path of the binary file to write.
//   wordlistFile optional list of words (one per line) that must be included;
//                each is added in lower-case and in Capitalized form. A
//                missing file only produces a warning.
//   modelid      integer copied verbatim into the output header.
//   limit        stop taking words from popFile once this many words
//                (wordlist words included) are selected; <= 0 means no limit.
//
// Numbering: popFile words get indices 0, 1, ... in order of first
// appearance (wordlist words take the rank of their first occurrence there).
// Wordlist words not reached in popFile are appended afterwards in sorted
// order. Selected words with no embedding in inFile are omitted from the
// output; missing wordlist words are reported on stderr.
//
// Output layout (native byte order, each "int" is sizeof(int) bytes):
//   int -1 (sentinel), int 1 (version), int modelid, int word count, int dim
//   then for each written word, in index order:
//     int length, `length` bytes of the word,
//     float sum of squares of the raw vector (its squared L2 norm),
//     dim floats: the vector scaled to unit length.
void processWord2Vec(const char* inFile, const char* popFile, const char* outFile, const char* wordlistFile, int modelid, int limit) {
    string line;

    // --- Required game words ------------------------------------------------
    set<string> wordlist;
    ifstream fin(wordlistFile);
    if (fin) {
        while (getline(fin, line)) {
            trimTrailingWhitespace(line);
            if (line.empty()) continue;
            string lower = tolower(line);
            wordlist.insert(lower);
            wordlist.insert(capitalize(lower));
        }
        fin.close();
    }
    else {
        cerr << "Warning: missing " << wordlistFile << ", so unable to ensure all words from there are available." << endl;
    }

    // --- Assign output indices ------------------------------------------------
    // -1 marks a wordlist word that has not been given an index yet.
    int popcount = 0; // (sorry)
    map<string, int> popularWords;
    for (const string& w : wordlist) popularWords[w] = -1;
    fin.open(popFile);
    assert(fin);
    while (getline(fin, line)) {
        // Tested on every line, with >=, so a wordlist that already reaches
        // the limit stops intake instead of silently disabling the limit.
        if (limit > 0 && popularWords.size() >= (size_t)limit)
            break;
        istringstream iss(line);
        string word;
        if (!(iss >> word))
            continue;  // blank line
        auto it = popularWords.find(word);
        if (it == popularWords.end())
            popularWords.emplace(word, popcount++);
        else if (it->second == -1)
            it->second = popcount++;
        // Otherwise this is a repeated entry: keep its first (most popular) rank.
    }
    fin.close();
    for (const string& w : wordlist) {
        int& index = popularWords[w];
        if (index == -1)
            index = popcount++;
    }
    assert((int)popularWords.size() == popcount);

    // --- Read embeddings ------------------------------------------------------
    struct Word {
        string word;          // empty = no embedding found; not written
        float norm = 0;       // sum of squares of the raw vector
        vector<float> vec;    // raw vector scaled to unit length
    };
    fin.open(inFile);
    assert(fin);
    vector<Word> words(popcount);
    int dim = -1, count = 0;
    // Stop reading as soon as every selected word has been found.
    while (count < popcount && getline(fin, line)) {
        trimTrailingWhitespace(line);
        if (line.empty()) continue;
        size_t ind = line.find_first_of(" \t");
        assert(ind != string::npos && ind != 0);
        Word w;
        w.word = line.substr(0, ind);
        auto indexit = popularWords.find(w.word);
        if (indexit == popularWords.end())
            continue;  // not selected, or a repeat of an already-read word
        int index = indexit->second;
        popularWords.erase(indexit);
        double norm = 0;
        if (dim != -1)
            w.vec.reserve(dim);
        // Components are separated by one or more spaces/tabs.
        size_t pos = line.find_first_not_of(" \t", ind);
        while (pos != string::npos) {
            size_t end = line.find_first_of(" \t", pos);
            double x = parseDouble(line.substr(pos, end == string::npos ? string::npos : end - pos));
            w.vec.push_back((float)x);
            norm += x * x;
            pos = line.find_first_not_of(" \t", end);
        }
        if (dim == -1) dim = (int)w.vec.size();
        else assert((int)w.vec.size() == dim);
        w.norm = (float)norm;
        // An all-zero vector cannot be normalized: 1/sqrt(0) is infinite and
        // 0 * inf is NaN, so leave such a vector as zeros.
        if (norm > 0) {
            double mu = 1 / sqrt(norm);
            for (float& x : w.vec) x = (float)(x * mu);
        }
        wordlist.erase(w.word);
        words[index] = move(w);
        count++;
    }
    fin.close();
    if (!wordlist.empty()) {
        cerr << "Warning: words not found:" << endl;
        for (const string &w : wordlist)
            cerr << w << endl;
    }

    // --- Write the binary model -----------------------------------------------
    ofstream fout(outFile, ios::binary);
    if (!fout) {
        cerr << "Error: unable to open " << outFile << " for writing." << endl;
        return;
    }
    int sentinel = -1;
    int version = 1;
    fout.write((const char*)&sentinel, sizeof sentinel);
    fout.write((const char*)&version, sizeof version);
    fout.write((const char*)&modelid, sizeof modelid);
    fout.write((const char*)&count, sizeof count);
    fout.write((const char*)&dim, sizeof dim);
    for (const Word& w : words) {
        int len = (int)w.word.size();
        if (!len) continue;  // selected but no embedding was found
        fout.write((const char*)&len, sizeof len);
        fout.write(w.word.data(), len);
        fout.write((const char*)&w.norm, sizeof(float));
        fout.write((const char*)w.vec.data(), dim * sizeof(float));
    }
    fout.close();
    if (!fout)
        cerr << "Error: failed while writing " << outFile << "." << endl;
}
// Parse an entire command-line argument as an int. Unlike atoi(), this
// rejects empty input, trailing junk such as "50k", and values outside the
// range of int, instead of silently returning 0 or a truncated number.
// Surrounding whitespace is tolerated. Returns false on failure and leaves
// `out` untouched.
static bool parseIntArg(const char* text, int& out) {
    istringstream iss(text);
    int value;
    char extra;
    if (!(iss >> value) || (iss >> extra))
        return false;
    out = value;
    return true;
}

// Command-line entry point. Validates the arguments and runs
// processWord2Vec, using "wordlist.txt" (opened relative to the current
// directory) as the required-word list. Returns 1 on a usage or argument
// error and 0 otherwise.
int main(int argc, char **argv) {
    if (argc != 6) {
        cerr << "Usage: " << argv[0] << " <word2vec .txt file> <popularity .txt file> <model id> <limit> <outfile.bin>" << endl;
        cerr << endl;
        cerr << "* The word2vec file should be a list of lines of the form \"word a_1 a_2 ... a_k\"," << endl;
        cerr << " where k is the dimension of the word2vec embedding, a_i are real numbers in decimal form," << endl;
        cerr << " and words are lower-case with spaces replaced by underscores." << endl;
        cerr << endl;
        cerr << "* The popularity file contains words to be included, in order of decreasing commonness." << endl;
        cerr << " Only the first token of every line is considered; thus, word2vec txt files can be used here as well." << endl;
        cerr << endl;
        cerr << "* Additionally, if wordlist.txt exists, it is prepended to the popularity file." << endl;
        cerr << " It is intended to contain all the words from the game." << endl;
        cerr << endl;
        cerr << "* The model id is an arbitrary integer representing the model." << endl;
        cerr << endl;
        cerr << "* The limit indicates the number of words from the popularity file to use. 0 = unlimited." << endl;
        cerr << " Around 50,000 is reasonable." << endl;
        return 1;
    }
    const char* inFile = argv[1];
    const char* popFile = argv[2];
    const char* outFile = argv[5];
    int modelid = 0;
    int limit = 0;
    if (!parseIntArg(argv[3], modelid)) {
        cerr << "Error: model id must be an integer, got \"" << argv[3] << "\"." << endl;
        return 1;
    }
    if (!parseIntArg(argv[4], limit)) {
        cerr << "Error: limit must be an integer, got \"" << argv[4] << "\"." << endl;
        return 1;
    }
    processWord2Vec(inFile, popFile, outFile, "wordlist.txt", modelid, limit);
    return 0;
}