# 欠損値を含むデータに対しても、欠損の発生原因がわからないまま機械学習を行うことができる。
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Train and evaluate a CatBoost classifier on data that still contains
# missing values. Numeric columns are left untouched, so any NaN in them is
# handed to CatBoost as-is; only categorical columns are rewritten below.

# Load the data.
df = pd.read_csv("train.csv")

# Exclude the Name column.
df = df.drop(columns=["Name"])

# Separate the target variable from the explanatory variables.
X = df.drop(columns=["Survived"])
y = df["Survived"]

# Collect the categorical (object-dtype) columns.
categorical_features = X.select_dtypes(include=["object"]).columns.tolist()

# Replace missing values with "missing", then convert to strings.
# The order matters: astype(str) turns NaN into the literal string "nan",
# after which fillna() has nothing left to fill and the "missing" marker
# would never be applied.
for col in categorical_features:
    X[col] = X[col].fillna("missing").astype(str)

# Split into training and test sets (20% held out, fixed seed).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build the Pools, declaring which columns are categorical.
train_pool = Pool(X_train, y_train, cat_features=categorical_features)
test_pool = Pool(X_test, y_test, cat_features=categorical_features)

# Define and train the model.
model = CatBoostClassifier(iterations=300, depth=6, learning_rate=0.1, loss_function="Logloss", verbose=False)
model.fit(train_pool)

# Predict on the held-out set and report accuracy.
y_pred = model.predict(test_pool)
print("Accuracy:", accuracy_score(y_test, y_pred))
# 欠損値を含むデータに対しても、欠損の発生原因がわからないまま機械学習を行うことができる。
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Train and evaluate a CatBoost classifier on data that still contains
# missing values. Numeric columns are left untouched, so any NaN in them is
# handed to CatBoost as-is; only categorical columns are rewritten below.

# Load the data.
df = pd.read_csv("train.csv")

# Exclude the Name column.
df = df.drop(columns=["Name"])

# Separate the target variable from the explanatory variables.
X = df.drop(columns=["Survived"])
y = df["Survived"]

# Collect the categorical (object-dtype) columns.
categorical_features = X.select_dtypes(include=["object"]).columns.tolist()

# Replace missing values with "missing", then convert to strings.
# The order matters: astype(str) turns NaN into the literal string "nan",
# after which fillna() has nothing left to fill and the "missing" marker
# would never be applied.
for col in categorical_features:
    X[col] = X[col].fillna("missing").astype(str)

# Split into training and test sets (20% held out, fixed seed).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Build the Pools, declaring which columns are categorical.
train_pool = Pool(X_train, y_train, cat_features=categorical_features)
test_pool = Pool(X_test, y_test, cat_features=categorical_features)

# Define and train the model.
model = CatBoostClassifier(iterations=300, depth=6, learning_rate=0.1, loss_function="Logloss", verbose=False)
model.fit(train_pool)

# Predict on the held-out set and report accuracy.
y_pred = model.predict(test_pool)
print("Accuracy:", accuracy_score(y_test, y_pred))