- Introduction
- Prerequisites
- Installation
- Usage
- Configuration
- Output
- Interpretation of Results
- Troubleshooting
- Privacy and Ethical Considerations
- Contributing
- License
This project uses a tabular generative adversarial network (the CTGAN model) to synthesize medical data. Given real input data in tabular format, it generates realistic synthetic records. The primary goal is to produce high-quality synthetic data that preserves the statistical properties and relationships of the original data while protecting patient privacy.
Key features:
- Data preprocessing and encoding
- CTGAN model training and synthetic data generation
- Comprehensive quality assessment of synthetic data
- Visualization of data distributions and correlations
- Machine learning utility testing
Before you begin, ensure you have met the following requirements:
- Python 3.7 or higher
- pip (Python package installer)
- Access to a command-line interface
- Basic understanding of Python and machine learning concepts
Follow these steps to set up the project environment:
- Clone the repository or download the script:
  `git clone [repository-url]`
  `cd [repository-name]`
- (Ubuntu only) Create the environment automatically; if you do this, skip the next two steps:
  `source ./create_env.sh`
- (Optional but recommended) Create and activate a virtual environment:
  `python -m venv ganvenv`
  `source ganvenv/bin/activate` (on Windows, use `ganvenv\Scripts\activate`)
- Install the required packages:
  `pip install -r gan_requirements.txt`
  Detailed package versions:
  - pandas (1.2.0 or higher)
  - numpy (1.19.0 or higher)
  - scikit-learn (0.24.0 or higher)
  - seaborn (0.11.0 or higher)
  - matplotlib (3.3.0 or higher)
  - sdv (0.13.1 or higher)
- If you plan to use GPU acceleration, ensure CUDA is installed and install the matching PyTorch version.
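To confirm the installation, a small script can compare the installed package versions against the minimums listed above. This is a minimal sketch that is not part of the repository: `check_environment` and `version_tuple` are hypothetical names, it needs Python 3.8+ for `importlib.metadata` (on 3.7, install the `importlib-metadata` backport instead), and its simplified version parser ignores pre-release tags.

```python
import sys
from importlib import metadata  # Python 3.8+

# Minimum versions as listed in this README; gan_requirements.txt may pin others.
MIN_VERSIONS = {
    "pandas": "1.2.0",
    "numpy": "1.19.0",
    "scikit-learn": "0.24.0",
    "seaborn": "0.11.0",
    "matplotlib": "3.3.0",
    "sdv": "0.13.1",
}


def version_tuple(version):
    """Turn '1.2' or '1.24.3rc1' into a comparable 3-tuple of leading integers."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes such as 'rc1' or 'post1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # pad so that '1.2' compares equal to '1.2.0'
    return tuple(parts)


def check_environment(min_versions=MIN_VERSIONS):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    if sys.version_info < (3, 7):
        problems.append(f"Python 3.7+ required, found {sys.version.split()[0]}")
    for package, minimum in min_versions.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if version_tuple(installed) < version_tuple(minimum):
            problems.append(f"{package} {installed} is older than {minimum}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    for issue in issues:
        print("WARNING:", issue)
    if not issues:
        print("Environment looks OK")
```

Run it inside the activated environment; any warnings point to packages that `pip install -r gan_requirements.txt` did not install or upgrade.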
To run the script:
- Prepare your input data:
  - Ensure your data is a delimited text file that uses tabs, not commas, as the separator.
  - The file must include the columns 'set/split' and 'finalsplit'.
  - Remove any direct patient identifiers.
- Update the script:
  - Open `tabular_gan_medical_data.py` in a text editor.
  - Replace `'your_data.csv'` with the path to your input data file.
  - If your target column is not 'AD', replace 'AD' in the `machine_learning_utility_test` function call with the name of your target column.
- Place the split files:
  - Place the split files (`new_split_*`) in the same directory as the main script.
- Select the phenotype:
  - Set the `PHENOTYPE` variable on line 29 of the main script to the phenotype of interest so that the splits are made accordingly.
- Create the results directory:
  - Make sure a directory named `distribution_comparison_Degree` exists in the same directory as the main script; some of the result outputs are written there. If it does not exist, create it before proceeding to the next step.
- (Ubuntu only) Activate and run automatically; if you do this, skip the next two steps:
  - Run `source activate_run_env.sh` to activate the virtual environment and run the script with it.
- Activate the environment:
  - `source ganvenv/bin/activate` (on Windows, use `ganvenv\Scripts\activate`)
- Run the script:
  `python tabular_gan_medical_data.py`
- Review the output files and console messages for results and any error messages.
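The preparation steps can be checked automatically before starting a long training run. The sketch below is not part of the repository, and the function name `preflight` is hypothetical. It checks only what this README states: a tab-separated header containing `set/split` and `finalsplit`, the presence of `new_split_*` files, and the `distribution_comparison_Degree` directory, which it creates if missing.

```python
import csv
import glob
import os

REQUIRED_COLUMNS = ("set/split", "finalsplit")  # required input columns (from this README)
RESULTS_DIR = "distribution_comparison_Degree"   # results directory named in this README


def preflight(data_path, script_dir="."):
    """Check the input file and working directory; return a list of problems."""
    problems = []

    # 1. The input must be tab-separated and contain the required columns.
    with open(data_path, newline="") as handle:
        header = next(csv.reader(handle, delimiter="\t"), [])
    if len(header) <= 1:
        problems.append("header has at most one column; is the file really tab-separated?")
    missing = [col for col in REQUIRED_COLUMNS if col not in header]
    if missing:
        problems.append(f"missing required columns: {missing}")

    # 2. Split files (new_split_*) must sit next to the main script.
    if not glob.glob(os.path.join(script_dir, "new_split_*")):
        problems.append("no new_split_* files found next to the script")

    # 3. Create the results directory if it does not exist yet (side effect).
    os.makedirs(os.path.join(script_dir, RESULTS_DIR), exist_ok=True)
    return problems
```

For example, calling `preflight('your_data.csv')` from the script's directory returns an empty list when the data file and directory layout look ready.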
You can adjust the CTGAN model parameters at the top of the script:
- `EPOCHS`: Number of training epochs
- `BATCH_SIZE`: Number of samples per batch during training
- `GENERATOR_DIM`: Dimensions of the generator network
- `DISCRIMINATOR_DIM`: Dimensions of the discriminator network
- `GENERATOR_LR` and `DISCRIMINATOR_LR`: Learning rates
- `DISCRIMINATOR_STEPS`: Number of discriminator updates per generator update
- `EMBEDDING_DIM`: Size of the random sample passed to the generator
- `COMPRESS_DIMS` and `DECOMPRESS_DIMS`: Dimensions of the encoder and decoder
- `CUDA`: Whether to use GPU acceleration (if available)
Adjust these parameters based on your data size, complexity, and computational resources.
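Gathering the parameters in one place makes it easier to catch invalid combinations before training starts. This is a hypothetical sketch, not the script's code: the class name is invented, it covers only a subset of the parameters (it omits `COMPRESS_DIMS` and `DECOMPRESS_DIMS`), and its default values are illustrative placeholders that roughly follow the `ctgan` library defaults. The `pac` field is an assumption: `ctgan` groups discriminator inputs into packs of `pac` samples (10 by default), so the batch size is expected to be a multiple of it. Check this against your installed version.

```python
from dataclasses import dataclass, asdict


@dataclass
class CTGANSettings:
    """Hypothetical container mirroring the script's top-level parameters."""
    epochs: int = 300
    batch_size: int = 500
    generator_dim: tuple = (256, 256)
    discriminator_dim: tuple = (256, 256)
    generator_lr: float = 2e-4
    discriminator_lr: float = 2e-4
    discriminator_steps: int = 1
    embedding_dim: int = 128
    cuda: bool = True
    pac: int = 10  # assumption: ctgan's default pack size

    def validate(self):
        """Raise ValueError for settings that cannot work; return self otherwise."""
        if self.epochs <= 0 or self.batch_size <= 0:
            raise ValueError("epochs and batch_size must be positive")
        # The discriminator sees samples in packs of `pac`, so batches must split evenly.
        if self.batch_size % self.pac != 0:
            raise ValueError(
                f"batch_size ({self.batch_size}) must be a multiple of pac ({self.pac})"
            )
        if any(size <= 0 for size in self.generator_dim + self.discriminator_dim):
            raise ValueError("layer sizes must be positive")
        if self.generator_lr <= 0 or self.discriminator_lr <= 0:
            raise ValueError("learning rates must be positive")
        return self


if __name__ == "__main__":
    settings = CTGANSettings(epochs=50, batch_size=200).validate()
    print(asdict(settings))
```

For example, `CTGANSettings(batch_size=256).validate()` raises a `ValueError` because 256 is not a multiple of 10.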
The script generates several output files:
- `synthetic_medical_data.csv`: The generated synthetic data
- `combined_medical_data.csv`: Combined original and synthetic data
- `ctgan_medical_model.pkl`: Saved CTGAN model for future use
- Distribution comparison plots: PNG files for each numerical column
- `correlation_difference.png`: Heatmap of correlation differences
Console output includes:
- SDV evaluation results
- Machine learning utility test results
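After a run, a quick check confirms that the named output files were written. This is a minimal sketch with a hypothetical function name. It checks only the file names listed above and skips the per-column distribution plots, whose exact names and location are not documented here.

```python
import os

# Output file names listed in this README.
EXPECTED_FILES = (
    "synthetic_medical_data.csv",
    "combined_medical_data.csv",
    "ctgan_medical_model.pkl",
    "correlation_difference.png",
)


def missing_outputs(output_dir="."):
    """Return the expected output files that are absent from output_dir."""
    return [
        name
        for name in EXPECTED_FILES
        if not os.path.isfile(os.path.join(output_dir, name))
    ]
```

An empty list means all four files exist. A non-empty list usually means the run stopped early, so check the console messages for errors.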
- SDV Evaluation Results:
  - Provides metrics on how well the synthetic data matches the statistical properties of the real data.
  - Look for high scores (closer to 1.0), which indicate better quality.
- Distribution Comparison Plots:
  - Compare the shapes of the real and synthetic data distributions.
  - Look for similar overall shapes and ranges.
- Correlation Difference Heatmap:
  - Areas closer to white indicate well-preserved correlations.
  - Red or blue areas show correlations that are over- or under-represented in the synthetic data.
- Machine Learning Utility Test:
  - Compare accuracy and F1 scores between real and synthetic data.
  - Synthetic data performance should ideally be close to real data performance.
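The correlation difference heatmap is a cell-by-cell subtraction of two correlation matrices. Below is a minimal pure-Python sketch of that arithmetic. It assumes Pearson correlation and a synthetic-minus-real sign convention; the script may use a different correlation measure or the opposite sign, which would swap the meaning of red and blue. All function names are hypothetical.

```python
from statistics import mean


def pearson(xs, ys):
    """Pearson correlation of two equal-length numeric sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    if sx == 0 or sy == 0:
        return 0.0  # assumption: treat a constant column as uncorrelated
    return cov / (sx * sy)


def correlation_matrix(columns):
    """columns: dict mapping column name -> list of numbers."""
    names = list(columns)
    return {(a, b): pearson(columns[a], columns[b]) for a in names for b in names}


def correlation_difference(real, synthetic):
    """Synthetic-minus-real correlation for every pair of columns present in both.

    Positive values mean the correlation coefficient is higher in the synthetic
    data than in the real data; negative values mean it is lower; 0 means preserved.
    """
    shared = {name: real[name] for name in real if name in synthetic}
    real_corr = correlation_matrix(shared)
    synth_corr = correlation_matrix({name: synthetic[name] for name in shared})
    return {pair: synth_corr[pair] - real_corr[pair] for pair in real_corr}
```

With pandas, the same matrix comes from `synthetic_df.corr() - real_df.corr()` on the numerical columns. The sketch only spells out what each heatmap cell represents.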
Common issues and solutions:
- MemoryError: Reduce batch size or use a smaller subset of your data.
- CUDA out of memory: Reduce model dimensions or switch to CPU (set CUDA = False).
- Poor synthetic data quality: Try increasing epochs, adjusting network dimensions, or preprocessing data differently.
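The first remedy, reducing the batch size after a `MemoryError`, can be automated with a retry loop. This is a sketch under stated assumptions: `fit_with_backoff` and `train_fn` are hypothetical names (`train_fn` stands in for whatever builds and fits the model), and rounding down to a multiple of 10 assumes the `ctgan` default `pac=10`. GPU out-of-memory errors from PyTorch typically surface as a `RuntimeError`, not a `MemoryError`, so this loop does not catch them.

```python
def fit_with_backoff(train_fn, batch_size, pac=10, min_batch_size=None):
    """Call train_fn(batch_size), halving the batch size after each MemoryError.

    Returns (result_of_train_fn, batch_size_that_worked). Batch sizes are rounded
    down to a multiple of `pac` (assumed ctgan constraint) and never go below it.
    """
    floor = max(min_batch_size or pac, pac)
    while batch_size >= floor:
        try:
            return train_fn(batch_size), batch_size
        except MemoryError:
            smaller = (batch_size // 2) // pac * pac
            print(f"MemoryError at batch_size={batch_size}; retrying with {smaller}")
            batch_size = smaller
    raise MemoryError("no batch size down to the minimum fitted in memory")
```

Smaller batches mean more optimizer steps per epoch, so training takes longer and may behave differently. If quality drops, consider training on a smaller subset of the data instead.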
While this method helps create synthetic data, it's crucial to ensure:
- No personal identifiers are included in the input data.
- The synthetic data doesn't inadvertently reveal real patient information.
- Compliance with relevant data protection regulations (e.g., HIPAA, GDPR).
- Ethical use of the synthetic data, maintaining the same standards as for real patient data.
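One basic check for the second point is to look for synthetic rows that are exact copies of real records. This is a minimal sketch that assumes both tables are loaded as lists of dicts with string values (for example via `csv.DictReader`). The function name is hypothetical, and ignoring the split columns by default is an assumption. Values are compared as strings, so `63` and `63.0` do not match. Passing this check does not establish privacy on its own: near-duplicates and rare value combinations can still identify patients.

```python
def exact_copies(real_rows, synthetic_rows, ignore=("set/split", "finalsplit")):
    """Return the synthetic rows that exactly match some real row.

    Columns listed in `ignore` are left out of the comparison.
    """
    def key(row):
        # Order-independent fingerprint of the row's compared columns.
        return tuple(sorted((k, v) for k, v in row.items() if k not in ignore))

    real_keys = {key(row) for row in real_rows}
    return [row for row in synthetic_rows if key(row) in real_keys]
```

Any rows this returns deserve manual review before the synthetic data is shared.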
Contributions to improve the script or documentation are welcome. Please follow these steps:
- Fork the repository.
- Create a new branch for your feature.
- Commit your changes.
- Push to the branch.
- Create a new Pull Request.
[Specify the license under which this project is released, e.g., MIT, Apache 2.0, etc.]
For any questions or issues, please open an issue in the project repository.