Skip to content

Commit e202a66

Browse files
committed
add more claude tests
1 parent defd1b8 commit e202a66

3 files changed

Lines changed: 189 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010
]
1111
description = "A collection of utility functions for MEG analysis."
1212
readme = "README.md"
13-
requires-python = ">=3.6"
13+
requires-python = ">=3.10"
1414
classifiers = [
1515
"Programming Language :: Python :: 3",
1616
"License :: OSI Approved :: MIT License",

tests/test_decoding.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,113 @@ def test_predict_proba_along_n_jobs(self):
2525
out2 = decoding.predict_proba_along(clf, X, axes=0, n_jobs=2)
2626
np.testing.assert_array_equal(out1, out2)
2727

28+
class TestStratify(unittest.TestCase):
29+
30+
def _make_imbalanced(self):
31+
rng = np.random.default_rng(0)
32+
X = rng.standard_normal((40, 5))
33+
y = np.array([0] * 30 + [1] * 10)
34+
return X, y
35+
36+
def test_undersample_balances(self):
37+
"""Undersampling reduces all classes to the minority count."""
38+
X, y = self._make_imbalanced()
39+
_, y_out = decoding.stratify(X, y, strategy='undersample', random_state=0)
40+
counts = np.bincount(y_out)
41+
np.testing.assert_array_equal(counts, [10, 10])
42+
43+
def test_oversample_balances(self):
44+
"""Oversampling raises all classes to the majority count."""
45+
X, y = self._make_imbalanced()
46+
_, y_out = decoding.stratify(X, y, strategy='oversample', random_state=0)
47+
counts = np.bincount(y_out)
48+
np.testing.assert_array_equal(counts, [30, 30])
49+
50+
def test_y_matches_X(self):
51+
"""Each output label matches the label of the corresponding input row."""
52+
X, y = self._make_imbalanced()
53+
X_out, y_out = decoding.stratify(X, y, strategy='undersample', random_state=0)
54+
for i in range(len(y_out)):
55+
orig_idx = np.where(np.all(X == X_out[i], axis=1))[0]
56+
self.assertTrue(len(orig_idx) > 0)
57+
self.assertEqual(y[orig_idx[0]], y_out[i])
58+
59+
def test_output_contains_original_rows(self):
60+
"""Every row in the output is an exact copy of an input row."""
61+
X, y = self._make_imbalanced()
62+
X_out, _ = decoding.stratify(X, y, strategy='undersample', random_state=0)
63+
for row in X_out:
64+
self.assertTrue(any(np.allclose(row, orig) for orig in X))
65+
66+
def test_invalid_strategy_raises(self):
67+
"""Unknown strategy string raises ValueError."""
68+
X, y = self._make_imbalanced()
69+
with self.assertRaises(ValueError):
70+
decoding.stratify(X, y, strategy='badstrat')
71+
72+
def test_y_2d_raises(self):
73+
"""Passing a 2-D array as y raises ValueError."""
74+
X, _ = self._make_imbalanced()
75+
y_2d = np.ones((40, 2))
76+
with self.assertRaises(ValueError):
77+
decoding.stratify(X, y_2d)
78+
79+
def test_random_state_reproducible(self):
80+
"""The same random seed produces identical outputs on repeated calls."""
81+
X, y = self._make_imbalanced()
82+
X1, y1 = decoding.stratify(X, y, strategy='undersample', random_state=42)
83+
X2, y2 = decoding.stratify(X, y, strategy='undersample', random_state=42)
84+
np.testing.assert_array_equal(y1, y2)
85+
np.testing.assert_array_equal(X1, X2)
86+
87+
def test_random_state_different(self):
88+
"""Different seeds produce different sample orderings."""
89+
rng = np.random.default_rng(1)
90+
X = rng.standard_normal((200, 5))
91+
y = np.array([0] * 150 + [1] * 50)
92+
_, y1 = decoding.stratify(X, y, strategy='undersample', random_state=0)
93+
_, y2 = decoding.stratify(X, y, strategy='undersample', random_state=99)
94+
self.assertFalse(np.array_equal(y1, y2))
95+
96+
def test_already_balanced(self):
97+
"""Both strategies leave an already-balanced dataset at the same size."""
98+
rng = np.random.default_rng(0)
99+
X = rng.standard_normal((20, 5))
100+
y = np.array([0] * 10 + [1] * 10)
101+
for strategy in ('undersample', 'oversample'):
102+
_, y_out = decoding.stratify(X, y, strategy=strategy, random_state=0)
103+
counts = np.bincount(y_out)
104+
np.testing.assert_array_equal(counts, [10, 10])
105+
106+
def test_oversample_no_replace_when_equal(self):
107+
"""Oversampling a balanced dataset does not duplicate any samples."""
108+
rng = np.random.default_rng(0)
109+
X = rng.standard_normal((20, 5))
110+
y = np.array([0] * 10 + [1] * 10)
111+
_, y_out = decoding.stratify(X, y, strategy='oversample', random_state=0)
112+
counts = np.bincount(y_out)
113+
np.testing.assert_array_equal(counts, [10, 10])
114+
115+
def test_works_with_dataframe(self):
116+
"""stratify accepts a pandas DataFrame and returns correct row counts."""
117+
import pandas as pd
118+
N = 40
119+
X = pd.DataFrame(np.eye(N))
120+
y = np.array([0] * 30 + [1] * 10)
121+
X_out, y_out = decoding.stratify(X, y, strategy='undersample', random_state=0)
122+
self.assertEqual(len(X_out), 20)
123+
self.assertEqual(len(y_out), 20)
124+
125+
def test_verbose_no_crash(self):
126+
"""verbose=True prints the target count without raising an error."""
127+
import io
128+
from contextlib import redirect_stdout
129+
X, y = self._make_imbalanced()
130+
buf = io.StringIO()
131+
with redirect_stdout(buf):
132+
decoding.stratify(X, y, strategy='undersample', random_state=0, verbose=True)
133+
self.assertIn('10', buf.getvalue())
134+
135+
28136
if __name__ == "__main__":
29137
unittest.main(verbosity=2)

tests/test_sigproc.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import matplotlib.pyplot as plt
1414
from meg_utils.sigproc import curves, fit_curve
1515
from meg_utils.sigproc import notch
16+
from meg_utils.sigproc import sliding_window
1617
from scipy.fft import rfft, rfftfreq
1718

1819
class TestFitCurve(unittest.TestCase):
@@ -98,5 +99,84 @@ def test_notch(self):
9899
out = notch(data, freqs=[50], sfreq=sfreq)
99100

100101

102+
class TestSlidingWindow(unittest.TestCase):
103+
104+
def test_output_shape_1d(self):
105+
"""Shape is (n_windows, win_size) for a 1-D input."""
106+
out = sliding_window(np.arange(10), win_size=4, stride=1)
107+
self.assertEqual(out.shape, (7, 4))
108+
109+
def test_output_shape_2d(self):
110+
"""Shape along the slid axis becomes n_windows, win_size is appended."""
111+
out = sliding_window(np.ones((3, 20)), win_size=5, stride=2)
112+
self.assertEqual(out.shape, (3, 8, 5))
113+
114+
def test_output_shape_axis0(self):
115+
"""Sliding along axis=0 replaces the first dimension with n_windows."""
116+
out = sliding_window(np.ones((20, 3)), win_size=5, stride=2, axis=0)
117+
self.assertEqual(out.shape, (8, 3, 5))
118+
119+
def test_output_shape_3d(self):
120+
"""Correct shape for a 3-D input sliding along the middle axis."""
121+
out = sliding_window(np.ones((2, 10, 4)), win_size=3, stride=2, axis=1)
122+
self.assertEqual(out.shape, (2, 4, 4, 3))
123+
124+
def test_window_values(self):
125+
"""Each window contains the correct consecutive elements."""
126+
out = sliding_window(np.arange(6), win_size=3, stride=1)
127+
np.testing.assert_array_equal(out[0], [0, 1, 2])
128+
np.testing.assert_array_equal(out[1], [1, 2, 3])
129+
np.testing.assert_array_equal(out[2], [2, 3, 4])
130+
np.testing.assert_array_equal(out[3], [3, 4, 5])
131+
132+
def test_stride_skips_correctly(self):
133+
"""stride > 1 advances the window start by that many elements."""
134+
out = sliding_window(np.arange(10), win_size=3, stride=3)
135+
np.testing.assert_array_equal(out[0], [0, 1, 2])
136+
np.testing.assert_array_equal(out[1], [3, 4, 5])
137+
np.testing.assert_array_equal(out[2], [6, 7, 8])
138+
139+
def test_is_view_shares_memory(self):
140+
"""Output shares memory with the input — no data copy is made."""
141+
arr = np.arange(100, dtype=float)
142+
out = sliding_window(arr, win_size=10, stride=1)
143+
self.assertTrue(np.shares_memory(arr, out))
144+
145+
def test_view_reflects_source_changes(self):
146+
"""Mutating the source array is immediately visible through the view."""
147+
arr = np.arange(20, dtype=float)
148+
out = sliding_window(arr, win_size=3, stride=1)
149+
arr[0] = 999.0
150+
self.assertEqual(out[0, 0], 999.0)
151+
152+
def test_read_only(self):
153+
"""Output is read-only; writing to it raises ValueError."""
154+
out = sliding_window(np.arange(10, dtype=float), win_size=3, stride=1)
155+
with self.assertRaises(ValueError):
156+
out[0, 0] = 99.0
157+
158+
def test_win_size_equals_length(self):
159+
"""win_size equal to axis length returns exactly one window."""
160+
arr = np.arange(5)
161+
out = sliding_window(arr, win_size=5, stride=1)
162+
self.assertEqual(out.shape, (1, 5))
163+
np.testing.assert_array_equal(out[0], arr)
164+
165+
def test_win_size_exceeds_length_raises(self):
166+
"""win_size larger than the axis length raises ValueError."""
167+
with self.assertRaises(ValueError):
168+
sliding_window(np.arange(5), win_size=6)
169+
170+
def test_invalid_stride_raises(self):
171+
"""stride <= 0 raises ValueError."""
172+
with self.assertRaises(ValueError):
173+
sliding_window(np.arange(10), win_size=3, stride=0)
174+
175+
def test_invalid_win_size_raises(self):
176+
"""win_size <= 0 raises ValueError."""
177+
with self.assertRaises(ValueError):
178+
sliding_window(np.arange(10), win_size=0)
179+
180+
101181
if __name__ == "__main__":
102182
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)