Customer Churn Prediction with LightGBM, CatBoost, TabTransformer, and AutoGluon-Tabular in Amazon SageMaker
This repository contains a Jupyter notebook that demonstrates how to use machine learning (ML) to automatically identify unhappy customers, a task also known as customer churn prediction. The notebook uses Amazon SageMaker's implementations of the LightGBM, CatBoost, TabTransformer, and AutoGluon-Tabular algorithms to train and host customer churn prediction models, with SageMaker Automatic Model Tuning (AMT) used for hyperparameter tuning.
Losing customers is costly for any business. Identifying unhappy customers early gives a business the chance to offer them incentives to stay. However, ML models rarely give perfect predictions, and different mistakes cost different amounts. This notebook shows how to incorporate the relative costs of prediction mistakes when determining the financial outcome of using ML.
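The idea of weighing prediction mistakes by their cost can be sketched as follows. This is a minimal, standard-library-only illustration and not the notebook's own code: the function names and the dollar amounts (a missed churner costing 500, a retention incentive costing 100) are hypothetical values chosen only for the example.

```python
def confusion_counts(labels, scores, threshold):
    """Count (tp, fp, fn, tn) when predicting churn for scores >= threshold."""
    tp = fp = fn = tn = 0
    for y, s in zip(labels, scores):
        predicted = 1 if s >= threshold else 0
        if predicted == 1 and y == 1:
            tp += 1
        elif predicted == 1 and y == 0:
            fp += 1
        elif predicted == 0 and y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def expected_cost(labels, scores, threshold,
                  cost_fn=500.0, cost_fp=100.0, cost_tp=100.0):
    """Total cost of acting on the predictions (all costs are assumed values).

    cost_fn: a churner we failed to flag and therefore lost.
    cost_fp: an incentive given to a customer who would have stayed anyway.
    cost_tp: an incentive given to a correctly flagged churner.
    """
    tp, fp, fn, _tn = confusion_counts(labels, scores, threshold)
    return fn * cost_fn + fp * cost_fp + tp * cost_tp


def best_threshold(labels, scores, candidates):
    """Pick the candidate threshold with the lowest total cost."""
    return min(candidates, key=lambda t: expected_cost(labels, scores, t))


# Toy data: 1 = churned, 0 = stayed; scores are predicted churn probabilities.
labels = [1, 1, 0, 0]
scores = [0.9, 0.4, 0.6, 0.1]
print(best_threshold(labels, scores, [0.3, 0.5, 0.7]))
```

Because a missed churner is assumed to cost far more than an unnecessary incentive, the cheapest threshold in this toy example is the lowest one, even though it produces a false positive.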
Each ML model in this notebook is used in two scenarios:
- Training a tabular model on the customer churn dataset with AMT.
- Using the trained tabular model to perform inference, i.e., classifying new samples.
Finally, the performance of the four trained models is compared on the same test data.
The notebook contains the following sections:
- Set Up: This part involves setting up the notebook and importing necessary libraries.
- Data Preparation and Visualization: This section deals with preparing the dataset for the machine learning models and visualizing the data for a better understanding of it. The data visualization includes histograms for each numeric feature. The dataset used is a publicly available set containing 5,000 records with 21 attributes describing the profile of a customer of an unknown US mobile operator. The data is processed and visualized to understand its distribution and its relation with the target variable 'Churn?'. This section also includes converting the target attribute to binary and moving it to the first column of the dataset to meet the requirements of SageMaker's built-in tabular algorithms.
- Train A LightGBM Model with AMT: This part involves training a LightGBM model with AMT and also includes sub-sections on retrieving training artifacts, setting training parameters, starting the training process, deploying and running inference on the trained model, and evaluating the prediction results.
- Train A CatBoost model with AMT: Similar to the previous section, this part focuses on the CatBoost model.
- Train A TabTransformer model with AMT: This section focuses on training a TabTransformer model with AMT.
- Train An AutoGluon-Tabular model: This section is about training an AutoGluon-Tabular model.
- Compare Prediction Results of Four Trained Models on the Same Test Data: In the final section, the performance of the four trained models - LightGBM, CatBoost, TabTransformer, and AutoGluon-Tabular - is compared using the same test data. The comparison includes a plot of the confusion matrix for each model.
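The target conversion described under Data Preparation and Visualization (binary label, moved to the first column) can be sketched as below. This is a standard-library-only illustration under stated assumptions, not the notebook's code, which more likely uses pandas. The helper names are hypothetical, the positive label string "True." and the two sample feature columns are assumptions about the raw data, and writing the file without a header row is also an assumption about what the training job expects.

```python
import csv
import io


def prepare_rows(rows, target="Churn?", positive="True."):
    """Convert the target to 0/1 and place it first in every row.

    rows: iterable of dicts, e.g. from csv.DictReader.
    positive: the raw string that means "churned" (assumed value).
    """
    prepared = []
    for row in rows:
        label = 1 if row[target].strip() == positive else 0
        features = [value for name, value in row.items() if name != target]
        prepared.append([label] + features)
    return prepared


def to_headerless_csv(prepared):
    """Serialize prepared rows as CSV text with no header row."""
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator="\n").writerows(prepared)
    return buffer.getvalue()


# Tiny made-up sample with the target in the last column.
raw = "State,Day Mins,Churn?\nKS,265.1,False.\nOH,161.6,True.\n"
rows = list(csv.DictReader(io.StringIO(raw)))
prepared = prepare_rows(rows)
print(to_headerless_csv(prepared))
```

Keeping the label in a fixed position means the training container does not need to know the column name, only that the first column is the target.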
This notebook uses the following machine learning models:
LightGBM is a gradient boosting framework that uses tree-based learning algorithms. It is designed to be distributed and efficient, with the following advantages: faster training speed and higher efficiency, lower memory usage, better accuracy, support for parallel and GPU learning, and the ability to handle large-scale data.
CatBoost is an open-source library for gradient boosting on decision trees with built-in support for categorical features. It provides the functionality needed to solve classification, regression, and ranking tasks, and it supports GPU acceleration.
TabTransformer is a model introduced in the paper "TabTransformer: Tabular Data Modeling Using Contextual Embeddings". It embeds each categorical feature as a token and applies Transformer layers to those embeddings, producing contextual embeddings that capture interactions between features. It is especially well suited to tabular data with high-cardinality categorical features.
AutoGluon-Tabular is an automated machine learning (AutoML) framework designed for tabular data. It automatically preprocesses the data, trains several candidate models, and combines them to solve the problem, which reduces the manual effort of model selection and hyperparameter tuning while still delivering strong predictive performance.
| | LightGBM with AMT | CatBoost with AMT | TabTransformer with AMT | AutoGluon-Tabular |
|---|---|---|---|---|
| Accuracy | 0.895556 | 0.953333 | 0.960000 | 0.980000 |
| F1 | 0.894855 | 0.953229 | 0.960526 | 0.980044 |
| AUC | 0.965412 | 0.991407 | 0.992040 | 0.997926 |
Each column corresponds to a model, and each row to a metric (Accuracy, F1, and AUC) computed for that model on the same test data.
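For readers who want to see what the three metrics measure, here is a minimal standard-library-only sketch. It is not the notebook's evaluation code (which may well use scikit-learn), and the function names are illustrative. The F1 shown is the binary F1 for the positive (churn) class; the averaging used to produce the table above is not stated, so it may differ. AUC is computed as the fraction of (churner, non-churner) pairs in which the churner receives the higher score, with ties counted as half.

```python
def accuracy(labels, preds):
    """Fraction of predictions that match the true labels."""
    return sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)


def f1_positive(labels, preds):
    """Binary F1 for the positive class: 2*TP / (2*TP + FP + FN)."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


def roc_auc(labels, scores):
    """Pairwise AUC: P(score of a positive > score of a negative).

    Quadratic in the number of samples, which is fine for a sketch.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


# Toy data: 1 = churned, 0 = stayed.
y_true = [1, 1, 0, 0]
y_score = [0.9, 0.4, 0.6, 0.1]
y_pred = [1 if s >= 0.5 else 0 for s in y_score]
print(accuracy(y_true, y_pred), f1_positive(y_true, y_pred), roc_auc(y_true, y_score))
```

Accuracy and F1 depend on the chosen decision threshold (0.5 here), whereas AUC depends only on how the scores rank the samples.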
To avoid incurring unnecessary charges, remember to delete the endpoint corresponding to the trained model. Instructions on how to do this are provided at the end of the notebook.
You can use this notebook to evaluate the performance of LightGBM, CatBoost, TabTransformer, and AutoGluon-Tabular on your own dataset. The notebook was tested in Amazon SageMaker Studio on an ml.t3.medium instance with the Python 3 (Data Science) kernel.
