diff --git a/ro_diacritics/diacritcs_train.py b/ro_diacritics/diacritcs_train.py index d9c8cef..6c597fb 100644 --- a/ro_diacritics/diacritcs_train.py +++ b/ro_diacritics/diacritcs_train.py @@ -1,3 +1,5 @@ +"""Training, evaluation and prediction routines for the diacritics model.""" + from pathlib import Path import numpy as np @@ -17,7 +19,21 @@ def train( epochs=10, checkpoint_file=None, ): - + """Train *model* with early stopping and periodic validation. + + Saves the best checkpoint (by validation accuracy) to *checkpoint_file* + whenever a new best is found. Training stops early when validation + accuracy has not improved for *patience* consecutive evaluation steps. + + :param model: :class:`~diacritics_model.Diacritics` model to train. + :param loss_func: Loss criterion (e.g. :class:`torch.nn.BCEWithLogitsLoss`). + :param train_dataloader: :class:`~torch.utils.data.DataLoader` for the + training set. + :param valid_dataloader: :class:`~torch.utils.data.DataLoader` for the + validation set, or ``None`` to skip validation. + :param epochs: Maximum number of training epochs. + :param checkpoint_file: Path where the best model checkpoint is saved. + """ optimizer = optim.Adam(model.parameters(), lr=0.001) device = next(model.parameters()).device print(f"{device} device for training") @@ -36,6 +52,17 @@ def train( patience = 3 def evaluate_step(step, epoch, running_loss, max_eval_steps=None): + """Run one evaluation pass on the validation set and update history. + + Saves a new checkpoint when validation accuracy improves. Increments + the non-improving counter otherwise (used for early stopping). + + :param step: Current training step within the epoch. + :param epoch: Current epoch index (0-based). + :param running_loss: Cumulative training loss up to *step*. + :param max_eval_steps: Cap on the number of validation batches to + process (``None`` means evaluate all). + """ nonlocal best_acc, best_acc_epoch, nr_non_improving if valid_dataloader is None: return @@ -142,7 +169,17 @@ def evaluate_step(step, epoch, running_loss, max_eval_steps=None): def evaluate(model, dataloader: DataLoader, loss_func, epoch=None, max_eval_steps=None): - # print("***** Running prediction *****") + """Evaluate *model* on *dataloader* and return loss, accuracy and F1. + + :param model: :class:`~diacritics_model.Diacritics` model to evaluate. + :param dataloader: :class:`~torch.utils.data.DataLoader` for the + evaluation set. + :param loss_func: Loss criterion matching the one used during training. + :param epoch: Current epoch index for logging (``None`` = test evaluation). + :param max_eval_steps: Maximum number of batches to process before + stopping early (``None`` = full evaluation). + :return: Tuple of ``(eval_loss, eval_acc, f1_metrics)``. + """ model.eval() predict_out = [] all_label_ids = [] @@ -203,7 +240,14 @@ def evaluate(model, dataloader: DataLoader, loss_func, epoch=None, max_eval_step def predict(model, dataloader: DataLoader): - # print("***** Running prediction *****") + """Run inference with *model* over *dataloader* and return softmax scores. + + :param model: :class:`~diacritics_model.Diacritics` model in eval mode. + :param dataloader: :class:`~torch.utils.data.DataLoader` yielding input + triples ``(char_input, word_emb, sentence_emb)``. + :return: List of per-sample softmax probability vectors (one list of + floats per sample, length = number of classes). + """ model.eval() predict_out = [] diff --git a/ro_diacritics/diacritics_dataset.py b/ro_diacritics/diacritics_dataset.py index f900e92..8a0eceb 100644 --- a/ro_diacritics/diacritics_dataset.py +++ b/ro_diacritics/diacritics_dataset.py @@ -1,23 +1,75 @@ +"""Dataset and vocabulary classes for Romanian diacritics restoration.""" + +import logging +import os import pickle as pkl from collections import Counter from pathlib import Path +import fasttext +import fasttext.util import numpy as np import torch from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer from nltk.tokenize import sent_tokenize, word_tokenize from torch.utils.data import IterableDataset -from torchtext.vocab import FastText from .diacritics_utils import ( correct_diacritics, remove_diacritics, has_interesting_chars, DIACRITICS_CANDIDATES, + LOG_NAME, ) +logger = logging.getLogger(LOG_NAME) + +FASTTEXT_CACHE_DIR = ".model" +FASTTEXT_MODEL_FILE = "cc.ro.300.bin" + + +def _load_fasttext_model(): + """Download (if necessary) and load the Romanian fastText binary model. + + The model file is cached in *FASTTEXT_CACHE_DIR*. ``fasttext.util`` + always downloads to the current working directory, so this function + temporarily changes the directory and restores it afterwards. + + :return: Loaded :class:`fasttext.FastText._FastText` model object. + """ + cache_dir = Path(FASTTEXT_CACHE_DIR) + cache_dir.mkdir(exist_ok=True) + model_path = cache_dir / FASTTEXT_MODEL_FILE + + if not model_path.exists(): + prev_dir = os.getcwd() + try: + os.chdir(cache_dir) + fasttext.util.download_model("ro", if_exists="ignore") + finally: + os.chdir(prev_dir) + + return fasttext.load_model(str(model_path)) + class DiacriticsVocab: + """Word-level vocabulary with pre-computed fastText word embeddings. + + For each token in *distinct_tokens* the corresponding 300-dimensional + fastText word vector is stored. Because fastText supports sub-word + information, out-of-vocabulary words still receive a meaningful vector. + + :param distinct_tokens: :class:`~collections.Counter` of token frequencies + collected from the training corpus. + :param max_vocab: Maximum number of vocabulary tokens (most frequent ones + are kept). + :param pad_token: Integer index used for the padding pseudo-token. + :param unk_token: Integer index used for the unknown-word pseudo-token. + :param max_char_vocab: Maximum Unicode code-point treated by the character + vocabulary; code-points above this are mapped to *overflow_char*. + :param overflow_char: Integer encoding used for out-of-range characters. + """ + def __init__( self, distinct_tokens: Counter, @@ -37,7 +89,8 @@ def __init__( self.max_char_vocab = max_char_vocab self.overflow_char = overflow_char - embedding = FastText("ro") + ft = _load_fasttext_model() + distinct_tokens = dict(distinct_tokens.most_common(max_vocab)) self.vocab["itos"] = dict(enumerate(distinct_tokens, 1)) @@ -46,30 +99,49 @@ def __init__( self.vocab["stoi"] = {v: k for k, v in self.vocab["itos"].items()} - fasttext_counts = Counter([remove_diacritics(word) for word in embedding.stoi]) - - for ( - word, - index, - ) in embedding.stoi.items(): # Aggregate embeddings with diacritics - word = remove_diacritics(word) - if word in self.vocab["stoi"]: - idx = self.vocab["stoi"][word] - self.vocab["vectors"][idx] = ( - self.vocab["vectors"][idx] + embedding.vectors[index] - ) - - for word, index in self.vocab["stoi"].items(): - if fasttext_counts[word] > 1: - self.vocab["vectors"][index] = ( - self.vocab["vectors"][index] / fasttext_counts[word] - ) + for word, idx in self.vocab["stoi"].items(): + if word in ("", ""): + continue + vec = ft.get_word_vector(word) + self.vocab["vectors"][idx] = torch.tensor(vec, dtype=torch.float32) def encode_char(self, c): + """Return the integer encoding for character *c*. + + Characters whose Unicode code-point exceeds *max_char_vocab* are + mapped to *overflow_char*. + + :param c: Single character to encode. + :return: Integer code-point or *overflow_char*. + """ return ord(c) if ord(c) <= self.max_char_vocab else self.overflow_char class DiacriticsDataset(IterableDataset): + """Iterable PyTorch dataset for Romanian diacritics restoration. + + Reads tokenised Romanian text and yields *(inputs, label)* tuples where + *inputs* is a triple of ``(char_tensor, word_embedding, sentence_embedding)`` + and *label* is a one-hot tensor of shape ``(3,)`` representing the correct + diacritic class for a candidate character position. + + :param data: Path to a plain-text file, a ``.pkl`` cache file produced by + a previous run, or a raw text string. Pass an empty string when only + running inference with a pre-built *diacritics_vocab*. + :param character_window: Half-width of the character context window (the + total window is ``2 * character_window + 1`` characters). + :param sentence_window: Width of the sliding sentence window in words. + :param min_line_length: Lines (or sentences) shorter than this are + discarded during loading. + :param max_vocab: Maximum vocabulary size (most frequent tokens kept). + :param max_char_vocab: Maximum Unicode code-point handled by the character + vocabulary; higher code-points are mapped to *overflow_char*. + :param overflow_char: Fallback integer for out-of-range characters. + :param diacritics_vocab: Pre-built :class:`DiacriticsVocab` to reuse + (e.g. loaded from a saved checkpoint). When ``None`` a new vocabulary + is built from *data* using fastText embeddings. + """ + def __init__( self, data, @@ -81,23 +153,13 @@ def __init__( overflow_char=255, diacritics_vocab: DiacriticsVocab = None, ): - """ - - :param data: Textfile, Pickle file or raw text - :param character_window: - :param sentence_window: - :param min_line_length: - :param max_vocab: - :param max_char_vocab: - :param overflow_char: - :param diacritics_vocab: - """ self.character_window = character_window self.sentence_window = sentence_window self.min_line_length = min_line_length self.texts, distinct_tokens = self.load_texts(data, self.min_line_length) self.max_vocab = max_vocab self.pad_character = 0 + self._sent_tokenizer = PunktSentenceTokenizer() if diacritics_vocab is None: self.max_char_vocab = max_char_vocab @@ -125,21 +187,29 @@ def __init__( @property def vocab(self): + """Return the underlying vocabulary dictionary (``itos``, ``stoi``, ``vectors``).""" return self.diacritics_vocab.vocab def encode_char(self, c): + """Encode a single character using the vocabulary's character mapping. + + :param c: Single character to encode. + :return: Integer code-point, capped at *max_char_vocab*. + """ return self.diacritics_vocab.encode_char(c) def __iter__(self): - return self.parse_text() + """Yield training samples by iterating over all stored sentences.""" + yield from self.parse_text() @staticmethod def get_label(original_char): - """ - :param original_char: lowercase diacritics char - :return: 0 if no change (not turing into diacritics), - 1 if changed to first candidate of diacritics - 2 if changed to second candidate of diacritics (a has 2 candidates) + """Convert a (possibly diacritical) character to a one-hot class label. + + :param original_char: Lowercase character from the original text. + :return: One-hot :class:`torch.Tensor` of shape ``(3,)`` where index 0 + means "no diacritic", 1 means "first candidate diacritic" (ă, î, + ș, ț), and 2 means "second candidate diacritic" (â). """ diacritic_to_label = { "ă": 1, @@ -159,6 +229,15 @@ def get_label(original_char): @staticmethod def get_char_from_label(original_char, label): + """Map a base character and a predicted class label back to the + (possibly diacritical) output character. + + :param original_char: Lowercase base character (e.g. ``'a'``). + :param label: Predicted integer class (0 = no change, 1 = first + candidate, 2 = second candidate). + :return: The correctly diacritised character, or *original_char* if no + transformation applies. + """ if original_char not in DIACRITICS_CANDIDATES: return original_char @@ -175,6 +254,25 @@ def get_char_from_label(original_char, label): return original_char def get_char_input(self, line, line_orig, token_idx): + """Build character-level input encodings for every diacritic-candidate + position in the token at *token_idx*. + + For each candidate character a fixed-width context window of encoded + characters is produced: *character_window* characters to the left of + the candidate (padded with zeros), the candidate itself, and + *character_window* characters to the right (padded with zeros). + + :param line: List of word strings (diacritics stripped, lower-cased). + :param line_orig: List of word strings in their original form (used to + determine the ground-truth label). + :param token_idx: Index of the target token inside *line*. + :return: Tuple of three lists aligned by position: + + * ``encoded_list`` – list of integer lists, each of length + ``2 * character_window + 1``. + * ``labels`` – list of one-hot label tensors. + * ``char_positions`` – list of intra-word character offsets. + """ prefix_s = " ".join(line[:token_idx]) suffix_s = " ".join(line[token_idx + 1 :]) word = line[token_idx] @@ -205,6 +303,14 @@ def get_char_input(self, line, line_orig, token_idx): return encoded_list, labels, char_positions def get_word_emb(self, word): + """Look up the pre-computed word embedding for *word*. + + Falls back to the ```` vector when *word* is not in the + vocabulary. + + :param word: Token string to look up. + :return: Embedding :class:`torch.Tensor` of shape ``(word_embedding_size,)``. + """ if word in self.vocab["stoi"]: idx = self.vocab["stoi"][word] embed = self.vocab["vectors"][idx] @@ -214,6 +320,15 @@ def get_word_emb(self, word): return embed def get_sentence_emb(self, sentence): + """Build a sentence-level embedding matrix for a window of words. + + Short sentences are right-padded with the ```` vector so the + result always has *sentence_window* rows. + + :param sentence: List of token strings. + :return: :class:`torch.Tensor` of shape + ``(sentence_window, word_embedding_size)``. + """ encoded = list(map(self.get_word_emb, sentence)) if len(encoded) < self.sentence_window: encoded = encoded + [self.get_word_emb("")] * ( @@ -222,7 +337,15 @@ def get_sentence_emb(self, sentence): return torch.stack(encoded) def parse_text(self): - "Generates one sample of data" + """Generate *(inputs, label)* training samples from :attr:`texts`. + + For every sentence in the corpus and every diacritic-candidate + character, a triple of ``(char_tensor, word_embedding, + sentence_embedding)`` is yielded together with the corresponding + one-hot label tensor. + + :yields: Tuple of ``((char_tensor, word_emb, sentence_emb), label)``. + """ # Select sample for line_orig in self.texts: line = [ @@ -251,19 +374,30 @@ def parse_text(self): ), labels[ix] def gen_batch(self, text, stride=1): + """Build inference tensors for all diacritic-candidate characters in + *text*. + + Unlike :meth:`parse_text`, this method does not require ground-truth + labels and returns the tensors together with the absolute character + positions in *text* so predictions can later be mapped back. + + :param text: Raw input text to restore diacritics for. + :param stride: Step size when sliding the sentence window over the + token sequence. Values greater than 1 reduce computation at the + cost of receiving fewer predictions per character (which are + averaged during post-processing). + :return: Tuple of two aligned lists: + + * ``input_tensors`` – list of ``[char_tensor, word_emb, + sentence_emb]`` triples. + * ``character_indices`` – list of absolute character offsets into + *text* corresponding to each entry in *input_tensors*. """ - - :param text: input text to be processed - :param stride: generate sentence windows with this stride (you have to pool the results since you will - get more than 1 prediciton per character) - :return: - """ - text_plain = remove_diacritics(text).lower() - lines = PunktSentenceTokenizer().span_tokenize(text) + sentence_spans = self._sent_tokenizer.span_tokenize(text) character_indices = [] input_tensors = [] - for line_span in lines: + for line_span in sentence_spans: line = text[line_span[0] : line_span[1]] words = list(TreebankWordTokenizer().span_tokenize(line)) word_indices = np.arange(len(words)) @@ -304,10 +438,28 @@ def gen_batch(self, text, stride=1): @staticmethod def load_texts(data, min_line_length): + """Load and pre-process text data from a file or raw string. + + Supports three input formats: + + * ``.pkl`` file – a previously pickled ``(texts, counter)`` tuple is + loaded directly. + * Plain-text file – lines are sentence-tokenised and then + word-tokenised; lines shorter than *min_line_length* are dropped. + * Raw text string – same sentence/word tokenisation is applied inline. + + :param data: File path (str or :class:`~pathlib.Path`) or raw text. + :param min_line_length: Minimum character length for a sentence to be + kept. + :return: Tuple of ``(texts, distinct_no_diacritics)`` where *texts* is + a list of token lists and *distinct_no_diacritics* is a + :class:`~collections.Counter` of diacritics-free token frequencies. + """ if data and Path(data).exists(): filename = data if Path(filename).suffix.lower() == ".pkl": # loading cached pickle - texts, distinct_no_diacritics = pkl.load(open(filename, "rb")) + with open(filename, "rb") as f: + texts, distinct_no_diacritics = pkl.load(f) else: with open(filename, "r", encoding="utf-8") as f: texts = [ diff --git a/ro_diacritics/diacritics_inference.py b/ro_diacritics/diacritics_inference.py index 7c61ed2..ddd231d 100644 --- a/ro_diacritics/diacritics_inference.py +++ b/ro_diacritics/diacritics_inference.py @@ -1,3 +1,5 @@ +"""High-level inference helpers: model loading, downloading, and diacritics restoration.""" + import zipfile import logging from pathlib import Path @@ -28,6 +30,22 @@ ) +class _TensorListDataset(IterableDataset): + """Lightweight :class:`~torch.utils.data.IterableDataset` wrapper around a + plain list of tensors produced by :meth:`~DiacriticsDataset.gen_batch`. + + :param input_tensors: List of ``[char_tensor, word_emb, sentence_emb]`` + triples. + """ + + def __init__(self, input_tensors): + self.input_tensors = input_tensors + + def __iter__(self): + """Yield each tensor triple in order.""" + return iter(self.input_tensors) + + def get_cached_model(): """ Loads the cached model. If not, it tries to download the model from github @@ -63,18 +81,24 @@ def inner(b=1, bsize=1, tsize=None): return filename -def load_model(filename) -> (Diacritics, DiacriticsDataset): - """ - Loads a trained :class:Diacritics model from cached file. - Also, used hyperparams are loaded from the file - :param filename: local path to stored model - :return: loaded :class:Diacritics object and the used vocabulary (must be the same as in training) +def load_model(filename) -> tuple[Diacritics, DiacriticsDataset]: + """Load a trained :class:`~diacritics_model.Diacritics` model from a + checkpoint file. + + The checkpoint is expected to contain the model ``state_dict``, + hyper-parameters, a serialised :class:`~diacritics_dataset.DiacriticsVocab` + vocabulary, and optional evaluation metrics. + + :param filename: Path to the ``.pt`` checkpoint produced by + :meth:`~diacritics_model.Diacritics.save`. + :return: Tuple of the reconstructed :class:`~diacritics_model.Diacritics` + model (in eval mode) and the saved vocabulary object. """ import sys sys.modules["diacritics_dataset"] = diacritics_dataset - checkpoint = torch.load(filename, map_location="cpu") + checkpoint = torch.load(filename, map_location="cpu", weights_only=False) params = checkpoint["hyperparams"] model = Diacritics( nr_classes=params["nr_classes"], @@ -89,7 +113,7 @@ def load_model(filename) -> (Diacritics, DiacriticsDataset): ) model.load_state_dict(checkpoint["model_state"]) - checkpoint["valid_f1"] = checkpoint["valid_f1"] if "valid_f1" in checkpoint else 0 + checkpoint["valid_f1"] = checkpoint.get("valid_f1", 0) logger.info( f"Loaded checkpoint: Epoch: {checkpoint['epoch']}, valid_acc: {checkpoint['valid_acc']}, valid_f1: {checkpoint['valid_f1']}" ) @@ -98,12 +122,20 @@ def load_model(filename) -> (Diacritics, DiacriticsDataset): def initmodel(filename=None): + """Initialise the global model and dataset used by :func:`restore_diacritics`. + + Downloads the pre-trained checkpoint when *filename* is ``None`` and the + cached file does not exist yet. + + :param filename: Optional path to a local ``.pt`` checkpoint file. When + omitted the default cached model is used (downloaded on first call). + """ global _model, _dataset if filename is None: filename = get_cached_model() _model, vocab = load_model(filename) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if device == "cpu": + if device.type == "cpu": logger.warning("GPU not available, using CPU") else: logger.info("GPU available") @@ -127,14 +159,7 @@ def restore_diacritics(text, batch_size=128): input_tensors, character_indices = _dataset.gen_batch(text, stride=10) - class DS(IterableDataset): - def __init__(self, input_tensors): - self.input_tensors = input_tensors - - def __iter__(self): - return iter(self.input_tensors) - - input_data = DataLoader(DS(input_tensors), batch_size=batch_size) + input_data = DataLoader(_TensorListDataset(input_tensors), batch_size=batch_size) predictions = predict(_model, input_data) diff --git a/ro_diacritics/diacritics_model.py b/ro_diacritics/diacritics_model.py index 7d80462..14c510e 100644 --- a/ro_diacritics/diacritics_model.py +++ b/ro_diacritics/diacritics_model.py @@ -1,3 +1,5 @@ +"""PyTorch model definition for Romanian diacritics restoration.""" + import torch import torch.nn as nn @@ -146,6 +148,20 @@ def forward(self, char_input, word_embedding, sentence_embedding): return out def save(self, filename, vocabulary, epoch=None, valid_acc=0.0, valid_f1=0.0): + """Serialise the model weights, hyper-parameters and vocabulary to + *filename* using :func:`torch.save`. + + The resulting checkpoint can be loaded with + :func:`~diacritics_inference.load_model`. + + :param filename: Destination file path for the checkpoint. + :param vocabulary: The :class:`~diacritics_dataset.DiacriticsVocab` + instance used during training (saved alongside the model so that + inference can reuse the same embeddings and character mappings). + :param epoch: Training epoch at which this checkpoint was saved. + :param valid_acc: Validation accuracy achieved at this checkpoint. + :param valid_f1: Weighted F1 score achieved at this checkpoint. + """ to_save = { "epoch": epoch, "model_state": self.state_dict(), diff --git a/ro_diacritics/diacritics_utils.py b/ro_diacritics/diacritics_utils.py index f394af5..812bd28 100644 --- a/ro_diacritics/diacritics_utils.py +++ b/ro_diacritics/diacritics_utils.py @@ -1,3 +1,5 @@ +"""Utility constants and helper functions for Romanian diacritics processing.""" + import re LOG_NAME = "ro-diacritics" @@ -56,25 +58,51 @@ "t": ["ț", "t"], } +# Pre-compiled regex patterns for efficiency +_RE_CORRECT_DIACRITICS = re.compile( + "|".join(re.escape(k) for k in MAP_CORRECT_DIACRITICS) +) +_RE_REMOVE_DIACRITICS = re.compile( + "|".join(re.escape(k) for k in MAP_DIACRITICS) +) +_RE_INTERESTING_CHARS = re.compile( + "|".join(re.escape(c) for c in DIACRITICS_CANDIDATES) +) + def correct_diacritics(word): - # use these three lines to do the replacement - rep = dict((re.escape(k), v) for k, v in MAP_CORRECT_DIACRITICS.items()) - pattern = re.compile("|".join(rep.keys())) - return pattern.sub(lambda m: rep[re.escape(m.group(0))], str(word)) + """Replace old-style cedilla diacritics (ş, Ş, ţ, Ţ) with the correct + comma-below forms (ș, Ș, ț, Ț) used in modern Romanian orthography. + + :param word: Input string that may contain old-style diacritics. + :return: String with old-style diacritics replaced by correct forms. + """ + return _RE_CORRECT_DIACRITICS.sub( + lambda m: MAP_CORRECT_DIACRITICS[m.group(0)], str(word) + ) def remove_diacritics(word): - # use these three lines to do the replacement - rep = dict((re.escape(k), v) for k, v in MAP_DIACRITICS.items()) - pattern = re.compile("|".join(rep.keys())) - return pattern.sub(lambda m: rep[re.escape(m.group(0))], str(word)) + """Strip all Romanian diacritical characters from *word*, returning the + plain ASCII base form (e.g. "ș" → "s", "ă" → "a"). + + :param word: Input string that may contain Romanian diacritics. + :return: String with all diacritics removed. + """ + return _RE_REMOVE_DIACRITICS.sub( + lambda m: MAP_DIACRITICS[m.group(0)], str(word) + ) def has_interesting_chars(word): - # use these three lines to do the replacement - pattern = re.compile("|".join(DIACRITICS_CANDIDATES)) - return pattern.search(str(word)) is not None + """Return True if *word* contains at least one character that could carry + a Romanian diacritic (i.e. one of 'a', 'i', 's', 't'). + + :param word: Input string to check. + :return: True if a diacritic-candidate character is found, False otherwise. + """ + return _RE_INTERESTING_CHARS.search(str(word)) is not None + __all__ = [ "correct_diacritics", @@ -85,4 +113,5 @@ def has_interesting_chars(word): "DIACRITICS_CANDIDATES", "MAP_CORRECT_DIACRITICS", "MAP_POSSIBLE_CHARS", + "LOG_NAME", ] \ No newline at end of file diff --git a/setup.py b/setup.py index 39457c1..eff9517 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,14 @@ # package_dir={"": "."}, packages=find_packages(), - install_requires=['torch', 'torchtext', 'numpy', 'tqdm', 'nltk', 'scikit-learn',], + install_requires=[ + 'torch>=2.0', + 'fasttext-wheel', + 'numpy', + 'tqdm', + 'nltk', + 'scikit-learn', + ], zip_safe=False, )