Skip to content

Commit 89168d5

Browse files
ENH: Create preprocessing module (#3)
* ENH: Initialize preprocessing module * ENH: Add preprocessing module with core functions * TST: Add preprocessing unit tests * STY: Pre-commit fixes
1 parent 05a5f41 commit 89168d5

File tree

10 files changed

+200
-6
lines changed

10 files changed

+200
-6
lines changed

.github/ISSUE_TEMPLATE/bug_report.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ body:
4848
Place traceback error here if applicable. If your issue has no traceback, please describe the observed output without formatting.
4949
```
5050
validations:
51-
required: true
51+
required: true

.github/ISSUE_TEMPLATE/doc_improvement.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ body:
1515
attributes:
1616
label: Suggest a potential alternative/fix
1717
description: >
18-
Tell us how you think the documentation could be improved.
18+
Tell us how you think the documentation could be improved.

.github/ISSUE_TEMPLATE/feature_request.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ body:
2727
attributes:
2828
label: Additional context
2929
description: >
30-
Add any other context about the problem here.
30+
Add any other context about the problem here.

.github/ISSUE_TEMPLATE/other_issue.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ body:
1818
label: Suggest a potential alternative/fix
1919
- type: textarea
2020
attributes:
21-
label: Additional context
21+
label: Additional context

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ A clear and concise description of what you have implemented.
1919

2020
<!--
2121
Please be aware that we are a team of volunteers so patience is necessary.
22-
-->
22+
-->

.github/workflows/pr_pytest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: Run Tests
22

33
on:
44
push:
5-
branches:
5+
branches:
66
- main
77
paths:
88
- "orca_python/**"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Preprocessing module."""
2+
3+
from .preprocessing import (
4+
normalize,
5+
preprocess_input,
6+
standardize,
7+
)
8+
9+
__all__ = [
10+
"normalize",
11+
"preprocess_input",
12+
"standardize",
13+
]
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""Preprocessing module."""
2+
3+
from sklearn import preprocessing
4+
5+
6+
def preprocess_input(X_train, X_test=None, input_preprocessing=None):
7+
"""Apply normalization or standardization to the input data.
8+
9+
The preprocessing is fit on the training data and then applied to both
10+
training and test data (if provided).
11+
12+
Parameters
13+
----------
14+
X_train : np.ndarray
15+
Feature matrix used specifically for model training.
16+
17+
X_test : np.ndarray, optional
18+
Feature matrix used for model evaluation and prediction.
19+
20+
input_preprocessing : str, optional
21+
Data normalization strategy:
22+
- "norm": Linear scaling
23+
- "std": Standardization
24+
- None: No preprocessing
25+
26+
Returns
27+
-------
28+
X_train_scaled : np.ndarray
29+
Scaled training data.
30+
31+
X_test_scaled : np.ndarray, optional
32+
Scaled test data.
33+
34+
Raises
35+
------
36+
ValueError
37+
If an unknown preprocessing method is specified.
38+
39+
"""
40+
if input_preprocessing is None:
41+
return X_train, X_test
42+
43+
input_preprocessing = input_preprocessing.lower()
44+
if input_preprocessing == "norm":
45+
X_train_scaled, X_test_scaled = normalize(X_train, X_test)
46+
elif input_preprocessing == "std":
47+
X_train_scaled, X_test_scaled = standardize(X_train, X_test)
48+
else:
49+
raise ValueError(f"Input preprocessing named '{input_preprocessing}' unknown")
50+
51+
return X_train_scaled, X_test_scaled
52+
53+
54+
def normalize(X_train, X_test=None):
55+
"""Normalize the data.
56+
57+
Test data normalization will be based on train data.
58+
59+
Parameters
60+
----------
61+
X_train : np.ndarray
62+
Feature matrix used specifically for model training.
63+
64+
X_test : np.ndarray, optional
65+
Feature matrix used for model evaluation and prediction.
66+
67+
Returns
68+
-------
69+
X_train_normalized : np.ndarray
70+
Normalized training data.
71+
72+
X_test_normalized : np.ndarray, optional
73+
Normalized test data.
74+
75+
"""
76+
scaler = preprocessing.MinMaxScaler()
77+
X_train_normalized = scaler.fit_transform(X_train)
78+
X_test_normalized = scaler.transform(X_test) if X_test is not None else None
79+
return X_train_normalized, X_test_normalized
80+
81+
82+
def standardize(X_train, X_test=None):
83+
"""Standardize the data.
84+
85+
Test data standardization will be based on train data.
86+
87+
Parameters
88+
----------
89+
X_train : np.ndarray
90+
Feature matrix used specifically for model training.
91+
92+
X_test : np.ndarray, optional
93+
Feature matrix used for model evaluation and prediction.
94+
95+
Returns
96+
-------
97+
X_train_standardized : np.ndarray
98+
Standardized training data.
99+
100+
X_test_standardized : np.ndarray, optional
101+
Standardized test data.
102+
103+
"""
104+
scaler = preprocessing.StandardScaler()
105+
X_train_standardized = scaler.fit_transform(X_train)
106+
X_test_standardized = scaler.transform(X_test) if X_test is not None else None
107+
return X_train_standardized, X_test_standardized
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
"""Tests for preprocessing module."""
2+
3+
__all__ = []
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Tests for the preprocessing module."""
2+
3+
import numpy as np
4+
import numpy.testing as npt
5+
import pytest
6+
7+
from orca_python.preprocessing import normalize, preprocess_input, standardize
8+
9+
10+
@pytest.fixture
11+
def dataset():
12+
"""Create synthetic dataset for testing preprocessing functions."""
13+
X_train = np.random.randn(100, 5)
14+
X_test = np.random.randn(50, 5)
15+
return X_train, X_test
16+
17+
18+
def test_normalize_data(dataset):
19+
"""Test that normalize function correctly scales input data to [0,1] range."""
20+
X_train, X_test = dataset
21+
norm_X_train, _ = normalize(X_train, X_test)
22+
assert np.all(norm_X_train >= 0) and np.all(norm_X_train <= 1)
23+
24+
25+
def test_standardize_data(dataset):
26+
"""Test that standardize function correctly produces output with zero mean and unit variance."""
27+
X_train, X_test = dataset
28+
std_X_train, _ = standardize(X_train, X_test)
29+
npt.assert_almost_equal(np.mean(std_X_train), 0, decimal=6)
30+
npt.assert_almost_equal(np.std(std_X_train), 1, decimal=6)
31+
32+
33+
@pytest.mark.parametrize(
34+
"input_preprocessing, method_func",
35+
[
36+
("norm", normalize),
37+
("std", standardize),
38+
],
39+
)
40+
def test_input_preprocessing(dataset, input_preprocessing, method_func):
41+
"""Test that different preprocessing methods work as expected."""
42+
X_train, X_test = dataset
43+
post_X_train, post_X_test = preprocess_input(X_train, X_test, input_preprocessing)
44+
expected_X_train, expected_X_test = method_func(X_train, X_test)
45+
npt.assert_array_almost_equal(post_X_train, expected_X_train)
46+
npt.assert_array_almost_equal(post_X_test, expected_X_test)
47+
48+
49+
def test_none_input_preprocessing(dataset):
50+
"""Test that preprocessing function handles None input correctly."""
51+
X_train, X_test = dataset
52+
post_X_train, post_X_test = preprocess_input(X_train, X_test, None)
53+
npt.assert_array_equal(post_X_train, X_train)
54+
npt.assert_array_equal(post_X_test, X_test)
55+
56+
57+
def test_input_preprocessing_unknown_method(dataset):
58+
"""Test that an unknown preprocessing method raises an AttributeError."""
59+
X_train, X_test = dataset
60+
error_msg = "Input preprocessing named 'esc' unknown"
61+
with pytest.raises(ValueError, match=error_msg):
62+
preprocess_input(X_train, X_test, "esc")
63+
64+
65+
def test_input_preprocessing_inconsistent_features(dataset):
66+
"""Test that preprocessing with inconsistent feature dimensions raises error."""
67+
X_train, X_test = dataset
68+
X_test = X_test[:, :-1]
69+
with pytest.raises(ValueError):
70+
preprocess_input(X_train, X_test, "norm")
71+
preprocess_input(X_train, X_test, "norm")

0 commit comments

Comments
 (0)