diff --git a/.github/workflows/model_test.yml b/.github/workflows/model_test.yml new file mode 100644 index 000000000..97cfadfea --- /dev/null +++ b/.github/workflows/model_test.yml @@ -0,0 +1,25 @@ +name: Model Tests + +on: + push: + branches: [ main, day5-homework ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f day5/requirements.txt ]; then pip install -r day5/requirements.txt; fi + - name: Run tests + run: | + cd day5/演習3 + pytest -v tests/ \ No newline at end of file diff --git "a/day5/\346\274\224\347\277\2223/tests/test_model.py" "b/day5/\346\274\224\347\277\2223/tests/test_model.py" index e11a19a5c..0132be895 100644 --- "a/day5/\346\274\224\347\277\2223/tests/test_model.py" +++ "b/day5/\346\274\224\347\277\2223/tests/test_model.py" @@ -171,3 +171,90 @@ def test_model_reproducibility(sample_data, preprocessor): assert np.array_equal( predictions1, predictions2 ), "モデルの予測結果に再現性がありません" + +def test_model_comparison_with_baseline(): + """現在のモデルをベースラインモデルと比較""" + # ベースラインモデルの読み込み(あらかじめ保存しておく必要あり) + baseline_model_path = os.path.join(os.path.dirname(__file__), "../baseline_models/baseline_model.pkl") + + if not os.path.exists(baseline_model_path): + pytest.skip("ベースラインモデルが存在しないためスキップします") + + # 現在のモデルの学習 + data = pd.read_csv(DATA_PATH) + X = data.drop("Survived", axis=1) + y = data["Survived"].astype(int) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + preprocessor = ColumnTransformer( + transformers=[ + ("num", Pipeline(steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]), ["Age", "Fare"]), + ("cat", Pipeline(steps=[("imputer", SimpleImputer(strategy="most_frequent")), ("onehot", OneHotEncoder(handle_unknown="ignore"))]), ["Pclass", "Sex", "Embarked"]), + ], + remainder="drop", + ) + + current_model = Pipeline( + steps=[ + ("preprocessor", preprocessor), + ("classifier", RandomForestClassifier(n_estimators=100, random_state=42)), + ] + ) + + current_model.fit(X_train, y_train) + + # ベースラインモデルの読み込み + with open(baseline_model_path, "rb") as f: + baseline_model = pickle.load(f) + + # 両方のモデルで予測 + baseline_pred = baseline_model.predict(X_test) + current_pred = current_model.predict(X_test) + + # 精度の比較 + baseline_accuracy = accuracy_score(y_test, baseline_pred) + current_accuracy = accuracy_score(y_test, current_pred) + + # 現在のモデルはベースラインと同等以上の性能であるべき + assert current_accuracy >= baseline_accuracy * 0.95, f"現在のモデル精度({current_accuracy:.4f})がベースライン({baseline_accuracy:.4f})より5%以上低下しています" + +def test_detailed_inference_time(): + """モデルの推論時間詳細テスト(バッチサイズ別)""" + # モデルとデータの準備 + data = pd.read_csv(DATA_PATH) + X = data.drop("Survived", axis=1) + y = data["Survived"].astype(int) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + preprocessor = ColumnTransformer( + transformers=[ + ("num", Pipeline(steps=[("imputer", SimpleImputer(strategy="median")), ("scaler", StandardScaler())]), ["Age", "Fare"]), + ("cat", Pipeline(steps=[("imputer", SimpleImputer(strategy="most_frequent")), ("onehot", OneHotEncoder(handle_unknown="ignore"))]), ["Pclass", "Sex", "Embarked"]), + ], + remainder="drop", + ) + + model = Pipeline( + steps=[ + ("preprocessor", preprocessor), + ("classifier", RandomForestClassifier(n_estimators=100, random_state=42)), + ] + ) + + model.fit(X_train, y_train) + + # 異なるバッチサイズでの推論時間テスト + batch_sizes = [1, 10, 50, 100] + for batch_size in batch_sizes: + # バッチサイズに合わせてデータを取得 + if batch_size <= len(X_test): + X_batch = X_test.iloc[:batch_size] + + # 推論時間計測 + start_time = time.time() + model.predict(X_batch) + inference_time = time.time() - start_time + + # バッチサイズ1の場合、0.01秒以内、その他は1秒以内であるべき + max_time = 0.01 if batch_size == 1 else 1.0 + assert inference_time < max_time, f"バッチサイズ{batch_size}での推論時間({inference_time:.4f}秒)が長すぎます" \ No newline at end of file