-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessing.py
More file actions
134 lines (104 loc) · 4.01 KB
/
preprocessing.py
File metadata and controls
134 lines (104 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import re
import unicodedata
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, TypeAlias
from sklearn.model_selection import train_test_split
from constants import DECODE, ENCODE, SEED
from utils import CMUDictType
# One entry per dictionary word kept by parse_cmu_dict:
# (index, word, pronunciations), where each pronunciation is a list of
# phoneme tokens (wrapped in "<BOS>"/"<EOS>" by parse_cmu_dict).
Graphemes: TypeAlias = List[Tuple[int, str, List[List[str]]]]
# Word index -> list of id-encoded reference pronunciations for that word
# (filled by build_ref_map from PhonemePair.phoneme).
RefMap: TypeAlias = Dict[int, List[List[int]]]
@dataclass(frozen=True)
class TokenConfig:
    """Token vocabularies and lookup tables for the encoder and decoder sides.

    ``frozen=True`` only prevents rebinding the attributes; the lists and
    dicts stored in them are still mutable objects.
    """

    # Encoder side
    # Grapheme (input character) vocabulary; sorted when built by parse_cmu_dict.
    encode_vocab: List[str]
    # Character -> id, the id being the character's position in encode_vocab.
    encode_char_to_id: Dict[str, int]
    # Decoder side
    # Phoneme-token vocabulary (includes "<BOS>"/"<EOS>" when built by parse_cmu_dict).
    decode_vocab: List[str]
    # Phoneme token -> id, the id being the token's position in decode_vocab.
    decode_char_to_id: Dict[str, int]
    # Inverse of decode_char_to_id, used to turn predicted ids back into tokens.
    decode_id_to_char: Dict[int, str]
@dataclass(frozen=True)
class PhonemePair:
    """One id-encoded example: a word and one of its pronunciations."""

    # Index of the source word, copied from element 0 of its Graphemes entry.
    index: int
    # Encoder ids of the word's characters.
    grapheme: List[int]
    # Decoder ids of one pronunciation (with <BOS>/<EOS> ids when the
    # pronunciation came from parse_cmu_dict).
    phoneme: List[int]
def build_ref_map(
    train_pairs: List[PhonemePair],
    val_pairs: List[PhonemePair],
    test_pairs: List[PhonemePair],
    graphemes: Graphemes,
) -> RefMap:
    """Collect every reference pronunciation for words that have more than one.

    Pairs are visited in train, validation, then test order. A pair is kept
    only when its word's entry in ``graphemes`` (looked up by position
    ``pair.index``) lists at least two pronunciations.

    Args:
        train_pairs: Training pairs.
        val_pairs: Validation pairs.
        test_pairs: Test pairs.
        graphemes: Entries indexed by ``PhonemePair.index``.

    Returns:
        A plain dict mapping word index to the ``phoneme`` id lists of its
        pairs, in visiting order. Words with a single pronunciation are absent.

    Raises:
        IndexError: If a pair's index is out of range for ``graphemes``.
    """
    references: RefMap = {}
    for split in (train_pairs, val_pairs, test_pairs):
        for pair in split:
            word_index = pair.index
            has_alternatives = len(graphemes[word_index][2]) > 1
            if has_alternatives:
                references.setdefault(word_index, []).append(pair.phoneme)
    return references
def generate_pairs(graphemes: Graphemes, config: TokenConfig) -> List[PhonemePair]:
    """Encode every (word, pronunciation) combination as an id-level pair.

    Each entry of ``graphemes`` is unpacked as ``(index, word, phonemes)``.
    One ``PhonemePair`` is emitted per pronunciation, and all pairs of a word
    share the same encoded-character list.

    Args:
        graphemes: Entries in the ``parse_cmu_dict`` layout, or a subset of
            them such as one split from ``split_and_generate_pairs``.
        config: Vocabularies used to map characters and phoneme tokens to ids.

    Returns:
        The pairs in input order, with pronunciations in their listed order.

    Raises:
        KeyError: If a character or phoneme token is missing from
            ``config.encode_char_to_id`` / ``config.decode_char_to_id``.
    """
    encode_ids = config.encode_char_to_id
    decode_ids = config.decode_char_to_id
    phoneme_pairs: List[PhonemePair] = []
    # Take the index stored in the entry, not the entry's position in this
    # list. When ``graphemes`` is a split subset, the position no longer
    # matches the index assigned by parse_cmu_dict. build_ref_map uses that
    # index to look words up in its own ``graphemes`` argument (presumably
    # the full list).
    for index, word, phonemes in graphemes:
        word_ids = [encode_ids[char] for char in word]
        for phoneme in phonemes:
            phoneme_ids = [decode_ids[token] for token in phoneme]
            phoneme_pairs.append(
                PhonemePair(index=index, grapheme=word_ids, phoneme=phoneme_ids)
            )
    return phoneme_pairs
def get_id_to_char(char_to_id: Dict[str, int]) -> Dict[int, str]:
    """Invert a token -> id mapping into an id -> token mapping.

    If several tokens share an id, the last one encountered wins.
    """
    inverse: Dict[int, str] = {}
    for token, token_id in char_to_id.items():
        inverse[token_id] = token
    return inverse
def get_char_to_id(vocab: List[str]) -> Dict[str, int]:
    """Map each token in ``vocab`` to its position.

    A token that appears more than once keeps the position of its last
    occurrence.
    """
    return dict((token, position) for position, token in enumerate(vocab))
def parse_cmu_dict(cmu_dict: CMUDictType) -> Tuple[Graphemes, TokenConfig]:
    """Filter and clean a CMU-style dictionary and build the token vocabularies.

    Each key is stripped and NFC-normalised. Keys not made solely of lowercase
    ASCII letters, apostrophes and hyphens are skipped. For every kept
    pronunciation, one trailing digit is removed from each token, and the
    sequence is wrapped in ``"<BOS>"`` / ``"<EOS>"``.

    Args:
        cmu_dict: Mapping iterated via ``.items()`` as
            ``word -> pronunciations``, each pronunciation being an iterable
            of phoneme strings. This is the assumed CMUDict layout, e.g.
            ``{"cat": [["K", "AE1", "T"]]}``; TODO confirm against
            ``utils.CMUDictType``.

    Returns:
        ``(graphemes, config)``:
            ``graphemes`` holds one ``(index, word, pronunciations)`` tuple per
            kept word, where ``index`` is the tuple's position in that list.
            ``config`` holds the sorted encoder/decoder vocabularies (extended
            with ``ENCODE`` / ``DECODE``) and their lookup tables.
    """
    # Compiled once per call instead of being looked up for every entry.
    word_pattern = re.compile(r"[a-z'-]+")
    stress_digit = re.compile(r"\d$")

    encode_chars: Set[str] = set()
    decode_tokens: Set[str] = set()
    graphemes: Graphemes = []

    for word, phonemes in cmu_dict.items():
        word = unicodedata.normalize("NFC", word.strip())
        # Skip non-alphabetic entries (fullmatch == the old ^...$ anchors,
        # since strip() already removed any trailing newline).
        if not word_pattern.fullmatch(word):
            continue

        phonemes_clean: List[List[str]] = []
        for phoneme in phonemes:
            # NOTE(review): variants that differ only in this trailing digit
            # become identical here and are kept as separate (duplicate)
            # pronunciations. Confirm whether they should be deduplicated.
            new_phoneme = (
                ["<BOS>"] + [stress_digit.sub("", p) for p in phoneme] + ["<EOS>"]
            )
            decode_tokens.update(new_phoneme)
            phonemes_clean.append(new_phoneme)

        # The index is the entry's position in ``graphemes``. generate_pairs
        # copies it into PhonemePair.index, and build_ref_map uses it to
        # index into a graphemes list.
        graphemes.append((len(graphemes), word, phonemes_clean))
        encode_chars.update(word)

    # Separate List-typed names instead of rebinding the Set-typed ones.
    encode_vocab: List[str] = sorted(encode_chars.union(ENCODE))
    encode_char_to_id = get_char_to_id(vocab=encode_vocab)

    decode_vocab: List[str] = sorted(decode_tokens.union(DECODE))
    decode_char_to_id = get_char_to_id(vocab=decode_vocab)
    decode_id_to_char = get_id_to_char(char_to_id=decode_char_to_id)

    return graphemes, TokenConfig(
        encode_vocab=encode_vocab,
        encode_char_to_id=encode_char_to_id,
        decode_vocab=decode_vocab,
        decode_char_to_id=decode_char_to_id,
        decode_id_to_char=decode_id_to_char,
    )
def split_and_generate_pairs(
    graphemes: Graphemes,
    config: TokenConfig,
    test_size: float = 0.2,
    val_size: float = 0.5,
) -> Tuple[List[PhonemePair], List[PhonemePair], List[PhonemePair]]:
    """Split words into train/val/test, then encode each split into pairs.

    Splitting happens on ``graphemes`` entries, so all pronunciations of a
    word end up in the same split. Both splits use ``random_state=SEED``.

    Args:
        graphemes: Word entries to split.
        config: Vocabularies passed to ``generate_pairs``.
        test_size: Fraction of words held out from training.
        val_size: Fraction of the held-out words assigned to validation; the
            remainder becomes the test set.

    Returns:
        ``(train_pairs, val_pairs, test_pairs)``.
    """
    train_words, held_out = train_test_split(
        graphemes, test_size=test_size, random_state=SEED
    )
    # Second split: the "test" side of this call (val_size of held_out)
    # becomes validation; the rest is the final test set.
    test_words, val_words = train_test_split(
        held_out, test_size=val_size, random_state=SEED
    )
    train_pairs, val_pairs, test_pairs = (
        generate_pairs(words, config)
        for words in (train_words, val_words, test_words)
    )
    return train_pairs, val_pairs, test_pairs