|
| 1 | +"""Tests for the preprocessing module.""" |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import numpy.testing as npt |
| 5 | +import pytest |
| 6 | + |
| 7 | +from orca_python.preprocessing import normalize, preprocess_input, standardize |
| 8 | + |
| 9 | + |
| 10 | +@pytest.fixture |
| 11 | +def dataset(): |
| 12 | + """Create synthetic dataset for testing preprocessing functions.""" |
| 13 | + X_train = np.random.randn(100, 5) |
| 14 | + X_test = np.random.randn(50, 5) |
| 15 | + return X_train, X_test |
| 16 | + |
| 17 | + |
| 18 | +def test_normalize_data(dataset): |
| 19 | + """Test that normalize function correctly scales input data to [0,1] range.""" |
| 20 | + X_train, X_test = dataset |
| 21 | + norm_X_train, _ = normalize(X_train, X_test) |
| 22 | + assert np.all(norm_X_train >= 0) and np.all(norm_X_train <= 1) |
| 23 | + |
| 24 | + |
| 25 | +def test_standardize_data(dataset): |
| 26 | + """Test that standardize function correctly produces output with zero mean and unit variance.""" |
| 27 | + X_train, X_test = dataset |
| 28 | + std_X_train, _ = standardize(X_train, X_test) |
| 29 | + npt.assert_almost_equal(np.mean(std_X_train), 0, decimal=6) |
| 30 | + npt.assert_almost_equal(np.std(std_X_train), 1, decimal=6) |
| 31 | + |
| 32 | + |
| 33 | +@pytest.mark.parametrize( |
| 34 | + "input_preprocessing, method_func", |
| 35 | + [ |
| 36 | + ("norm", normalize), |
| 37 | + ("std", standardize), |
| 38 | + ], |
| 39 | +) |
| 40 | +def test_input_preprocessing(dataset, input_preprocessing, method_func): |
| 41 | + """Test that different preprocessing methods work as expected.""" |
| 42 | + X_train, X_test = dataset |
| 43 | + post_X_train, post_X_test = preprocess_input(X_train, X_test, input_preprocessing) |
| 44 | + expected_X_train, expected_X_test = method_func(X_train, X_test) |
| 45 | + npt.assert_array_almost_equal(post_X_train, expected_X_train) |
| 46 | + npt.assert_array_almost_equal(post_X_test, expected_X_test) |
| 47 | + |
| 48 | + |
| 49 | +def test_none_input_preprocessing(dataset): |
| 50 | + """Test that preprocessing function handles None input correctly.""" |
| 51 | + X_train, X_test = dataset |
| 52 | + post_X_train, post_X_test = preprocess_input(X_train, X_test, None) |
| 53 | + npt.assert_array_equal(post_X_train, X_train) |
| 54 | + npt.assert_array_equal(post_X_test, X_test) |
| 55 | + |
| 56 | + |
| 57 | +def test_input_preprocessing_unknown_method(dataset): |
| 58 | + """Test that an unknown preprocessing method raises an AttributeError.""" |
| 59 | + X_train, X_test = dataset |
| 60 | + error_msg = "Input preprocessing named 'esc' unknown" |
| 61 | + with pytest.raises(ValueError, match=error_msg): |
| 62 | + preprocess_input(X_train, X_test, "esc") |
| 63 | + |
| 64 | + |
| 65 | +def test_input_preprocessing_inconsistent_features(dataset): |
| 66 | + """Test that preprocessing with inconsistent feature dimensions raises error.""" |
| 67 | + X_train, X_test = dataset |
| 68 | + X_test = X_test[:, :-1] |
| 69 | + with pytest.raises(ValueError): |
| 70 | + preprocess_input(X_train, X_test, "norm") |
| 71 | + preprocess_input(X_train, X_test, "norm") |
0 commit comments