Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ body:
Place traceback error here if applicable. If your issue has no traceback, please describe the observed output without formatting.
```
validations:
required: true
required: true
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/doc_improvement.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ body:
attributes:
label: Suggest a potential alternative/fix
description: >
Tell us how you think the documentation could be improved.
Tell us how you think the documentation could be improved.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/feature_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ body:
attributes:
label: Additional context
description: >
Add any other context about the problem here.
Add any other context about the problem here.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/other_issue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ body:
label: Suggest a potential alternative/fix
- type: textarea
attributes:
label: Additional context
label: Additional context
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ A clear and concise description of what you have implemented.

<!--
Please be aware that we are a team of volunteers so patience is necessary.
-->
-->
2 changes: 1 addition & 1 deletion .github/workflows/pr_pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Run Tests

on:
push:
branches:
branches:
- main
paths:
- "orca_python/**"
Expand Down
13 changes: 13 additions & 0 deletions orca_python/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Preprocessing module."""

from .preprocessing import (
normalize,
preprocess_input,
standardize,
)

__all__ = [
"normalize",
"preprocess_input",
"standardize",
]
107 changes: 107 additions & 0 deletions orca_python/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Preprocessing module."""

from sklearn import preprocessing


def preprocess_input(X_train, X_test=None, input_preprocessing=None):
"""Apply normalization or standardization to the input data.

The preprocessing is fit on the training data and then applied to both
training and test data (if provided).

Parameters
----------
X_train : np.ndarray
Feature matrix used specifically for model training.

X_test : np.ndarray, optional
Feature matrix used for model evaluation and prediction.

input_preprocessing : str, optional
Data normalization strategy:
- "norm": Linear scaling
- "std": Standardization
- None: No preprocessing

Returns
-------
X_train_scaled : np.ndarray
Scaled training data.

X_test_scaled : np.ndarray, optional
Scaled test data.

Raises
------
ValueError
If an unknown preprocessing method is specified.

"""
if input_preprocessing is None:
return X_train, X_test

input_preprocessing = input_preprocessing.lower()
if input_preprocessing == "norm":
X_train_scaled, X_test_scaled = normalize(X_train, X_test)
elif input_preprocessing == "std":
X_train_scaled, X_test_scaled = standardize(X_train, X_test)
else:
raise ValueError(f"Input preprocessing named '{input_preprocessing}' unknown")

return X_train_scaled, X_test_scaled


def normalize(X_train, X_test=None):
"""Normalize the data.

Test data normalization will be based on train data.

Parameters
----------
X_train : np.ndarray
Feature matrix used specifically for model training.

X_test : np.ndarray, optional
Feature matrix used for model evaluation and prediction.

Returns
-------
X_train_normalized : np.ndarray
Normalized training data.

X_test_normalized : np.ndarray, optional
Normalized test data.

"""
scaler = preprocessing.MinMaxScaler()
X_train_normalized = scaler.fit_transform(X_train)
X_test_normalized = scaler.transform(X_test) if X_test is not None else None
return X_train_normalized, X_test_normalized


def standardize(X_train, X_test=None):
"""Standardize the data.

Test data standardization will be based on train data.

Parameters
----------
X_train : np.ndarray
Feature matrix used specifically for model training.

X_test : np.ndarray, optional
Feature matrix used for model evaluation and prediction.

Returns
-------
X_train_standardized : np.ndarray
Standardized training data.

X_test_standardized : np.ndarray, optional
Standardized test data.

"""
scaler = preprocessing.StandardScaler()
X_train_standardized = scaler.fit_transform(X_train)
X_test_standardized = scaler.transform(X_test) if X_test is not None else None
return X_train_standardized, X_test_standardized
3 changes: 3 additions & 0 deletions orca_python/preprocessing/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Tests for preprocessing module."""

__all__ = []
71 changes: 71 additions & 0 deletions orca_python/preprocessing/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Tests for the preprocessing module."""

import numpy as np
import numpy.testing as npt
import pytest

from orca_python.preprocessing import normalize, preprocess_input, standardize


@pytest.fixture
def dataset():
"""Create synthetic dataset for testing preprocessing functions."""
X_train = np.random.randn(100, 5)
X_test = np.random.randn(50, 5)
return X_train, X_test


def test_normalize_data(dataset):
"""Test that normalize function correctly scales input data to [0,1] range."""
X_train, X_test = dataset
norm_X_train, _ = normalize(X_train, X_test)
assert np.all(norm_X_train >= 0) and np.all(norm_X_train <= 1)


def test_standardize_data(dataset):
"""Test that standardize function correctly produces output with zero mean and unit variance."""
X_train, X_test = dataset
std_X_train, _ = standardize(X_train, X_test)
npt.assert_almost_equal(np.mean(std_X_train), 0, decimal=6)
npt.assert_almost_equal(np.std(std_X_train), 1, decimal=6)


@pytest.mark.parametrize(
"input_preprocessing, method_func",
[
("norm", normalize),
("std", standardize),
],
)
def test_input_preprocessing(dataset, input_preprocessing, method_func):
"""Test that different preprocessing methods work as expected."""
X_train, X_test = dataset
post_X_train, post_X_test = preprocess_input(X_train, X_test, input_preprocessing)
expected_X_train, expected_X_test = method_func(X_train, X_test)
npt.assert_array_almost_equal(post_X_train, expected_X_train)
npt.assert_array_almost_equal(post_X_test, expected_X_test)


def test_none_input_preprocessing(dataset):
"""Test that preprocessing function handles None input correctly."""
X_train, X_test = dataset
post_X_train, post_X_test = preprocess_input(X_train, X_test, None)
npt.assert_array_equal(post_X_train, X_train)
npt.assert_array_equal(post_X_test, X_test)


def test_input_preprocessing_unknown_method(dataset):
"""Test that an unknown preprocessing method raises an AttributeError."""
X_train, X_test = dataset
error_msg = "Input preprocessing named 'esc' unknown"
with pytest.raises(ValueError, match=error_msg):
preprocess_input(X_train, X_test, "esc")


def test_input_preprocessing_inconsistent_features(dataset):
"""Test that preprocessing with inconsistent feature dimensions raises error."""
X_train, X_test = dataset
X_test = X_test[:, :-1]
with pytest.raises(ValueError):
preprocess_input(X_train, X_test, "norm")
preprocess_input(X_train, X_test, "norm")