-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
239 lines (192 loc) · 12.4 KB
/
dataset.py
File metadata and controls
239 lines (192 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import random
import torch
from torch.utils.data import Dataset
import argparse
import numpy as np
"""
The input-output pairs (x, y) of the NameDataset are of the following form:
x: Where was Khatchig Mouradian born?⁇Lebanon⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□⁇Lebanon⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
x: Where was Jacob Henry Studer born?⁇Columbus⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y: □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□⁇Columbus⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
Using the PAD_CHAR characters in y before the ⁇[place] keeps the trainer from
optimizing the model to predict the question, "Where was...".
Note that the NameDataset should take the pretraining_dataset defined in run.py
as an input. This is to allow the vocab specification of the NameDataset to be
the same as that of the pretraining dataset.
You don't need to implement anything in NameDataset.
"""
class NameDataset(Dataset):
    """Question/answer finetuning examples encoded with a pretraining vocabulary.

    Every line of ``data`` is expected to hold a question and an answer
    separated by a tab.  The vocabulary (``stoi``/``itos``) and ``block_size``
    are taken from ``pretraining_dataset`` so both datasets encode characters
    the same way.
    """

    def __init__(self, data, pretraining_dataset):
        """Store the shared vocabulary and split ``data`` into lines.

        Args:
            data: Raw text, one tab-separated question/answer pair per line.
                Non-ASCII characters are dropped before splitting.
            pretraining_dataset: Object exposing ``itos``, ``stoi`` and
                ``block_size``.
        """
        self.MASK_CHAR = u"\u2047" # the doublequestionmark character, for mask
        self.PAD_CHAR = u"\u25A1" # the empty square character, for pad

        # Reuse the pretraining vocabulary and context length verbatim.
        self.block_size = pretraining_dataset.block_size
        self.stoi = pretraining_dataset.stoi
        self.itos = pretraining_dataset.itos

        # Round-trip through bytes to discard every non-ASCII character.
        ascii_text = data.encode('utf-8').decode('ascii', errors='ignore')
        self.data = ascii_text.split('\n')

    def __len__(self):
        """Return the number of examples.

        This is one less than the number of newline-separated entries;
        presumably the file ends with a newline, which leaves a trailing
        empty entry -- confirm against the data files.
        """
        return len(self.data) - 1

    def __getitem__(self, idx):
        """Return the ``(x, y)`` Long tensors for the example at ``idx``.

        ``x`` is ``question MASK answer MASK pads`` with its last character
        removed.  ``y`` is the same string shifted left by one position, with
        the question part replaced by pads so that only the answer is a
        prediction target.
        """
        question, answer = self.data[idx].split('\t')

        body = question + self.MASK_CHAR + answer + self.MASK_CHAR
        padded = body + self.PAD_CHAR * (self.block_size - len(body))

        # Blank out the question in the target; keep everything after it.
        target = self.PAD_CHAR * (len(question) - 1) + padded[len(question):]
        source = padded[:-1]

        def to_tensor(text):
            return torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)

        return to_tensor(source), to_tensor(target)
"""
Write a class that yields examples of a simplified span corruption objective.
Do not change the signature of the __init__ or __getitem__ functions.
--------------
Vocabulary Specification
Your vocabulary is to be accessible via two dictionaries:
self.stoi: a dictionary from characters in the vocabulary to indices of type
int
self.itos: a dictionary from indices of type int to characters in the
vocabulary
Your vocabulary must have the following form:
Identifier 0 must be assigned to the unicode element u"\u25A1".
This is the empty_square_character.
Further, let self.PAD_CHAR = u"\u25A1"
Identifier 1 must be assigned to the unicode element u"\u2047".
This is the doublequestionmark character, which we'll use
as a sentinel to represent that text is missing from the input
Further, let self.MASK_CHAR = u"\u2047"
Identifiers 2, ..., len(self.itos)-1 should be the sorted list of characters
that appear in the data argument.
--------------
Masking Specification
The __getitem__ function takes an index and returns a data point (x, y) where
x and y are Long tensors of length self.block_size. x encodes the input
sequence, and y encodes the output sequence.
0. Use the idx argument of __getitem__ to retrieve the element of self.data
at the given index. We'll call the resulting data entry a document.
1. Randomly truncate the document to a length no less than 4 characters,
and no more than int(self.block_size*7/8) characters.
- IMPORTANT: You are free to decide how to perform this random truncation, but
make sure that the length is picked _randomly_ (every possible length from 4
to int(self.block_size*7/8) has a chance of being picked) for full credit.
2. Now, break the (truncated) document into three substrings:
[prefix] [masked_content] [suffix]
In other words, choose three strings prefix, masked_content and suffix
such that prefix + masked_content + suffix = [the truncated document].
The length of [masked_content] should be random, and 1/4 the length of the
truncated document on average.
- IMPORTANT: You are free to decide how to perform this operation, but
make sure that the length is picked _randomly_ (has a chance of being more or
less than 1/4 the length of the truncated document).
3. Rearrange these substrings into the following form:
[prefix] MASK_CHAR [suffix] MASK_CHAR [masked_content] MASK_CHAR [pads]
This resulting string, denoted masked_string, serves as the output example.
Here MASK_CHAR is the masking character defined in Vocabulary Specification,
and [pads] is a string of repeated PAD_CHAR characters chosen so that the
entire string is of length self.block_size.
Intuitively, the [masked_content], a string, is removed from the document and
replaced with MASK_CHAR (the masking character defined in Vocabulary
Specification). After the suffix of the string, the MASK_CHAR is seen again,
followed by the content that was removed, and the padding characters.
4. We now use masked_string to construct the input and output example pair. To
do so, simply take the input string to be masked_string[:-1], and the output
string to be masked_string[1:]. In other words, for each character, the goal is
to predict the next character in the masked string.
5. Making use of the vocabulary that you defined, encode the resulting input
and output strings as Long tensors and return the resulting data point.
----------------
Here are some examples of input-output pairs (x, y):
x: Khatchig Mouradian. Khatchig Mouradian is a jour⁇and tran⁇nalist, writer ⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y: hatchig Mouradian. Khatchig Mouradian is a jour⁇and tran⁇nalist, writer ⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
x: Jaco⁇enry ⁇b H⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y: aco⁇enry ⁇b H⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
x: John Stephen. Born in Glasgow, Steph⁇lder's apprentice on⁇en became a we⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
y: ohn Stephen. Born in Glasgow, Steph⁇lder's apprentice on⁇en became a we⁇□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□
"""
class CharCorruptionDataset(Dataset):
    """Character-level dataset yielding simplified span-corruption examples.

    The vocabulary is built from the raw ``data`` text before it is split
    into lines: index 0 is ``PAD_CHAR``, index 1 is ``MASK_CHAR`` and the
    remaining indices are the sorted unique characters of ``data``.  Each
    newline-separated line of ``data`` is one document.
    """

    def __init__(self, data, block_size):
        """Build the vocabulary and split ``data`` into documents.

        Args:
            data: Raw training text; must contain neither ``MASK_CHAR`` nor
                ``PAD_CHAR``.
            block_size: Length of every padded masked string; the returned
                tensors have ``block_size - 1`` elements.
        """
        self.MASK_CHAR = u"\u2047" # the doublequestionmark character, for mask
        self.PAD_CHAR = u"\u25A1" # the empty square character, for pad

        chars = sorted(set(data))
        assert self.MASK_CHAR not in chars
        assert self.PAD_CHAR not in chars
        # The two special characters always take identifiers 0 and 1.
        chars = [self.PAD_CHAR, self.MASK_CHAR] + chars

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))

        self.block_size = block_size
        self.max_context_size = int(block_size*3/4)
        self.masking_percent = 0.25
        self.vocab_size = vocab_size
        self.data = data.split('\n')
        self.item = None

    def __len__(self):
        """Return the number of documents (newline-separated entries)."""
        return len(self.data)

    def _encode(self, text):
        """Map every character of ``text`` to its vocabulary index as a Long tensor."""
        return torch.tensor([self.stoi[c] for c in text], dtype=torch.long)

    def __getitem__(self, idx):
        """Return one span-corruption example ``(x, y)`` for document ``idx``.

        The document is truncated to a random length, a random contiguous
        span is cut out, and the pieces are rearranged as
        ``prefix MASK suffix MASK masked_content MASK pads``.  ``x`` is that
        string without its last character and ``y`` is the string without its
        first character.  An empty document yields all-pad tensors.

        Returns:
            A pair of Long tensors with ``block_size - 1`` elements each.
        """
        document = self.data[idx]

        if not document:
            # Nothing to corrupt (e.g. the entry after a trailing newline).
            masked_string = self.PAD_CHAR * self.block_size
        else:
            # Slicing past the end simply keeps the whole document, so short
            # documents need no special case.
            trunc_length = random.randint(4, int(self.block_size * 7/8))
            truncated_document = document[:trunc_length]

            # Use the un-truncated rate so the mean really is a quarter of
            # the length; rounding the rate down made it 0 for documents
            # shorter than four characters, which then never got masked.
            rate = self.masking_percent * len(truncated_document)
            masked_length = min(int(np.random.poisson(rate)),
                                len(truncated_document))

            # masked_length never exceeds the document length, so this range
            # is always valid.
            start_idx = random.randint(0, len(truncated_document) - masked_length)
            end_idx = start_idx + masked_length

            prefix = truncated_document[:start_idx]
            masked_content = truncated_document[start_idx:end_idx]
            suffix = truncated_document[end_idx:]

            masked_string = (prefix + self.MASK_CHAR + suffix + self.MASK_CHAR
                             + masked_content + self.MASK_CHAR)
            masked_string += self.PAD_CHAR * (self.block_size - len(masked_string))

        # Next-character prediction: the target is the input shifted by one.
        return self._encode(masked_string[:-1]), self._encode(masked_string[1:])
"""
Code under here is strictly for your debugging purposes; feel free to modify
as desired.
"""
if __name__ == '__main__':
    # Debugging entry point: build the requested dataset and print its first
    # four (x, y) examples decoded back into characters.

    def _read_text(path):
        """Return the UTF-8 contents of ``path``, closing the file afterwards."""
        with open(path, encoding='utf-8') as handle:
            return handle.read()

    def _print_examples(dataset, count=4):
        """Print the first ``count`` examples of ``dataset`` as strings."""
        for _, (x, y) in zip(range(count), dataset):
            print('x:', ''.join(dataset.itos[int(c)] for c in x))
            print('y:', ''.join(dataset.itos[int(c)] for c in y))

    argp = argparse.ArgumentParser()
    argp.add_argument('dataset_type',
                      help="Type of dataset to sample from. "
                           "Options: namedata, charcorruption.",
                      choices=["namedata", "charcorruption"])
    args = argp.parse_args()

    if args.dataset_type == 'namedata':
        # Even if it hasn't been implemented, we use it to define the vocab
        corruption_dataset = CharCorruptionDataset(
            _read_text('./../data/wiki.txt'), 128)
        # Make the name dataset
        name_dataset = NameDataset(
            _read_text('./../data/birth_places_train.tsv'), corruption_dataset)
        _print_examples(name_dataset)
    elif args.dataset_type == 'charcorruption':
        corruption_dataset = CharCorruptionDataset(
            _read_text('./../data/wiki.txt'), 128)
        _print_examples(corruption_dataset)
    else:
        # Not reachable while argparse enforces `choices`; kept as a guard in
        # case the choices list and these branches drift apart.
        raise ValueError("Unknown dataset type in command line args: {}"
                         .format(args.dataset_type))