This repository contains a fine-tuned BERT-based model for detecting whether two sets of Korean characters are part of the same word or separate words. This is particularly useful for Optical Character Recognition (OCR) tasks involving Korean text, where line wrapping can split words in ways that affect downstream processing.
This project fine-tunes a HuggingFace model, lassl/bert-ko-small, for a binary classification task. It receives two sets of Korean characters separated by whitespace (e.g., "한 마리") and predicts one of two states:
- State 0: The characters belong to the same word.
- State 1: The characters belong to different words.
The model has been extended with:
- A custom forward pass.
- A linear classifier layer for predicting binary labels.
A synthetic dataset was generated using the HuggingFace dataset wikimedia/wikipedia. The synthetic dataset mimics scenarios where Korean text may wrap across lines or be split in ways that require word boundary detection.
The code for generating this dataset is included in generate_dataset.py.
The training code is fully provided, including:
- Data loading.
- Model fine-tuning.
- Hyperparameters used for optimization.
This project uses the following major libraries:
- Python 3.8+
- Pandas
- PyTorch
- AWS SDK for Python (Boto3)
- Amazon SageMaker SDK
- tqdm
- HuggingFace Transformers
- HuggingFace Datasets
Install the required dependencies with:
```bash
pip install -r requirements.txt
```

The generate_dataset.py script generates a dataset of synthetic Korean text samples. It:
- Extracts text from the Wikipedia dataset.
- Simulates scenarios where Korean text wraps mid-word.
- Replaces numbers with the special token `<N>`.
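For illustration, the number-masking step could be sketched with a regular expression; this is an assumption about how generate_dataset.py implements it, not the actual code:

```python
import re

# Replace each run of digits with the special token <N>,
# as described for the dataset generation step above.
masked = re.sub(r"\d+", "<N>", "고양이 3마리")
print(masked)  # → 고양이 <N>마리
```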
The synthetic dataset is generated by iterating over the text in the wikimedia/wikipedia dataset: adjacent words are labeled 1 (different words), and words that are randomly split into two sets of characters are labeled 0 (same word). The occurrences of each string are counted and then randomly sampled from during training to reflect the natural distribution of language in the original dataset. Strings that occur 100 times or fewer are excluded from the synthetic dataset.
For example, the small corpus "고양이? 고양이 한 마리" could create the following dataset:
| text | state | count |
|---|---|---|
| "고양이 고양이" | 1 | 1 |
| "고양이 한" | 1 | 1 |
| "한 마리" | 1 | 1 |
| "고양 이" | 0 | 2 |
| "마 리" | 0 | 1 |
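The labeling scheme above can be sketched as follows. This is an illustrative reconstruction, not the actual generate_dataset.py code; the function name `label_corpus` is hypothetical:

```python
import random
from collections import Counter

def label_corpus(text: str) -> Counter:
    """Count (sample, state) pairs produced from a small corpus."""
    counts = Counter()
    # Strip trailing punctuation and split on whitespace.
    words = [w.strip("?.,!") for w in text.split()]
    # Adjacent word pairs -> state 1 (different words).
    for a, b in zip(words, words[1:]):
        counts[(f"{a} {b}", 1)] += 1
    # Randomly split each multi-character word -> state 0 (same word).
    for w in words:
        if len(w) > 1:
            i = random.randint(1, len(w) - 1)
            counts[(f"{w[:i]} {w[i:]}", 0)] += 1
    return counts

counts = label_corpus("고양이? 고양이 한 마리")
```

In the real pipeline these counts are accumulated over the whole Wikipedia corpus, and low-frequency strings are filtered out before training.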
To generate the synthetic training dataset, run:

```bash
python generate_dataset.py
```

To train the model, clone the repository, create a training dataset, and execute the training script:
```bash
python generate_dataset.py
python train.py
```

These scripts will:
- Generate the synthetic dataset.
- Sample N training examples per epoch, split equally between the two states (see Hyperparameters for more information).
- Fine-tune the model using the specified hyperparameters.
- Save the fine-tuned weights and tokenizer.
Note: You can modify the hyperparameters by creating a hyperparameters.json file. Ensure the hyperparameter names match those in the train.Hyperparameters dataclass.
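A hypothetical hyperparameters.json might look like the following. The field names here are guesses derived from the Hyperparameters section below and must be checked against the actual train.Hyperparameters dataclass:

```json
{
  "learning_rate": 5e-5,
  "batch_size": 8,
  "max_seq_length": 6,
  "num_epochs": 10,
  "test_split_size": 0.2,
  "n": 104
}
```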
To use the trained model for inference:
1. Clone the repository and load the model and tokenizer:

```python
import torch

from model import BertForWordBoundaryDetection

# Pick the best available device: Apple Silicon (MPS), then CUDA, then CPU.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

bert = BertForWordBoundaryDetection()
bert.load_state_dict(torch.load("pytorch_model.bin"))
bert.to(device)
```
2. Tokenize your input text and pass it through the model:

```python
inputs = bert.tokenize_function(input_text)
outputs = bert(
    input_ids=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
)
# Apply a sigmoid to the logits, then threshold at 0.5 for 0/1 predictions.
predictions = (torch.sigmoid(outputs) > 0.5).int()
```
An example inference script is included in infer.py; it provides a preprocess function and demonstrates running inference on an OCR output string. The example string used in infer.py is a page from The Whale (고래) by Cheon Myeong-kwan (천명관).
The fine-tuned model is based on lassl/bert-ko-small and includes:
- A custom linear classifier appended to the BERT encoder.
- A forward pass that extracts the `[CLS]` token's hidden state and passes it through the classifier.
More details about the architecture and implementation can be found in the model.py file.
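For illustration, the architecture described above might look roughly like the sketch below. It uses a small, randomly initialized BertConfig so the snippet is self-contained, whereas the real model.py loads lassl/bert-ko-small; the class internals here are assumptions:

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

class BertForWordBoundaryDetection(nn.Module):
    """BERT encoder plus a linear classifier producing one binary logit."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.bert = BertModel(config)
        # Single output logit for the binary same-word/different-word task.
        self.classifier = nn.Linear(config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The [CLS] token's hidden state is the first position in the sequence.
        cls_hidden = out.last_hidden_state[:, 0, :]
        return self.classifier(cls_hidden).squeeze(-1)

config = BertConfig(hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=128)
model = BertForWordBoundaryDetection(config)
logits = model(torch.ones(1, 6, dtype=torch.long),
               torch.ones(1, 6, dtype=torch.long))
```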
The following hyperparameters were used during training:
- Learning Rate: 5e-5
- Batch Size: 8
- Max Sequence Length: 6
- Number of Epochs: 10
- Test Split Size: 0.2
- N: 104
To modify hyperparameters, edit the train.py script or provide a hyperparameters.json file as described above.
Max Sequence Length was calculated as the 95th percentile of tokenized sequence lengths in the full processed dataset. The code for calculating the max sequence length can be found in the calculate_seq_length.py script.
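As a sketch of that calculation (the real code is in calculate_seq_length.py; the `lengths` list below is made-up sample data standing in for the tokenized sequence lengths of the full dataset):

```python
import numpy as np

# Token counts per sample; in practice these come from tokenizing
# the full processed dataset.
lengths = [2, 3, 3, 4, 4, 4, 5, 5, 6, 6]

# Max Sequence Length = 95th percentile of the tokenized lengths.
max_seq_length = int(np.percentile(lengths, 95))
print(max_seq_length)  # → 6
```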
The hyperparameter N is the number of samples drawn from the full processed dataset on each training epoch. N and Number of Epochs were chosen as 104 and 10 to favor more training iterations with smaller, more varied batches of training data.
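The per-epoch sampling described above might be sketched as follows. `sample_epoch` and its signature are hypothetical; the count-weighted draw mirrors the distribution-matching behavior described in the dataset section:

```python
import random

def sample_epoch(rows, n):
    """Draw n examples per epoch, split equally between states,
    weighted by each string's occurrence count."""
    half = n // 2
    batch = []
    for state in (0, 1):
        pool = [r for r in rows if r[1] == state]
        weights = [r[2] for r in pool]  # occurrence counts
        batch += random.choices(pool, weights=weights, k=half)
    return batch

# rows: (text, state, count) tuples from the synthetic dataset.
rows = [("고양 이", 0, 2), ("마 리", 0, 1), ("한 마리", 1, 1)]
epoch = sample_epoch(rows, 4)
```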
This repository includes an example Jupyter notebook deploy.ipynb for deploying this model for use on Amazon SageMaker AI. To deploy this model, create a Jupyter notebook on the Amazon SageMaker AI console and clone this repository:
- From the Create notebook instance menu, select Default repository under Git repositories.
- From the Repository dropdown menu, select Clone a public Git repository to this notebook instance only.
- Enter this repository's URL under Git repository URL.
When your notebook instance is deployed, run all cells in deploy.ipynb. The script contains detailed information about deployment.
Contributions are welcome! If you'd like to contribute, please:
- Fork this repository.
- Submit a pull request with your changes.
If you encounter any issues or have feature requests, feel free to open an issue.
This project builds on lassl/bert-ko-small. Special thanks to its creators for providing the pre-trained model.
This repository is open-source and available under the MIT License. See the LICENSE file for details.
Future work:
- Incorporate more robust datasets for training.
- Evaluate the model's performance on real-world OCR tasks.
- Optimize the model for deployment on edge devices.