@@ -25,5 +25,113 @@ def test_predict_proba_along_n_jobs(self):
2525 out2 = decoding .predict_proba_along (clf , X , axes = 0 , n_jobs = 2 )
2626 np .testing .assert_array_equal (out1 , out2 )
2727
class TestStratify(unittest.TestCase):
    """Tests for the class-balancing behaviour of ``decoding.stratify``."""

    def _make_imbalanced(self):
        """Return seeded (40, 5) features with 30 class-0 and 10 class-1 labels."""
        rng = np.random.default_rng(0)
        X = rng.standard_normal((40, 5))
        y = np.array([0] * 30 + [1] * 10)
        return X, y

    def _make_balanced(self):
        """Return seeded (20, 5) features with 10 labels in each of two classes."""
        rng = np.random.default_rng(0)
        X = rng.standard_normal((20, 5))
        y = np.array([0] * 10 + [1] * 10)
        return X, y

    def test_undersample_balances(self):
        """Undersampling reduces all classes to the minority count."""
        X, y = self._make_imbalanced()
        _, y_out = decoding.stratify(X, y, strategy='undersample', random_state=0)
        np.testing.assert_array_equal(np.bincount(y_out), [10, 10])

    def test_oversample_balances(self):
        """Oversampling raises all classes to the majority count."""
        X, y = self._make_imbalanced()
        _, y_out = decoding.stratify(X, y, strategy='oversample', random_state=0)
        np.testing.assert_array_equal(np.bincount(y_out), [30, 30])

    def test_y_matches_X(self):
        """Each output label matches the label of the corresponding input row."""
        X, y = self._make_imbalanced()
        X_out, y_out = decoding.stratify(X, y, strategy='undersample', random_state=0)
        # zip() would silently truncate on a length mismatch, so check it first.
        self.assertEqual(len(X_out), len(y_out))
        for row, label in zip(X_out, y_out):
            # Locate the input row that is element-wise identical to this one.
            orig_idx = np.flatnonzero(np.all(X == row, axis=1))
            self.assertGreater(len(orig_idx), 0)
            self.assertEqual(y[orig_idx[0]], label)

    def test_output_contains_original_rows(self):
        """Every row in the output is an exact copy of an input row."""
        X, y = self._make_imbalanced()
        X_out, _ = decoding.stratify(X, y, strategy='undersample', random_state=0)
        X_out = np.asarray(X_out)
        # Pairwise exact comparison: matches[i, j] is True when output row i
        # equals input row j in every column.
        matches = (X_out[:, None, :] == X[None, :, :]).all(axis=2)
        self.assertTrue(matches.any(axis=1).all())

    def test_invalid_strategy_raises(self):
        """Unknown strategy string raises ValueError."""
        X, y = self._make_imbalanced()
        with self.assertRaises(ValueError):
            decoding.stratify(X, y, strategy='badstrat')

    def test_y_2d_raises(self):
        """Passing a 2-D array as y raises ValueError."""
        X, _ = self._make_imbalanced()
        y_2d = np.ones((40, 2))
        with self.assertRaises(ValueError):
            decoding.stratify(X, y_2d)

    def test_random_state_reproducible(self):
        """The same random seed produces identical outputs on repeated calls."""
        X, y = self._make_imbalanced()
        X1, y1 = decoding.stratify(X, y, strategy='undersample', random_state=42)
        X2, y2 = decoding.stratify(X, y, strategy='undersample', random_state=42)
        np.testing.assert_array_equal(y1, y2)
        np.testing.assert_array_equal(X1, X2)

    def test_random_state_different(self):
        """Different seeds produce a different selection or ordering of samples."""
        rng = np.random.default_rng(1)
        X = rng.standard_normal((200, 5))
        y = np.array([0] * 150 + [1] * 50)
        X1, _ = decoding.stratify(X, y, strategy='undersample', random_state=0)
        X2, _ = decoding.stratify(X, y, strategy='undersample', random_state=99)
        # Compare the feature rows rather than the labels: label sequences
        # would be identical for every seed if the output is grouped by class,
        # whereas the rows reveal which samples were actually drawn.
        self.assertFalse(np.array_equal(np.asarray(X1), np.asarray(X2)))

    def test_already_balanced(self):
        """Both strategies leave an already-balanced dataset at the same size."""
        X, y = self._make_balanced()
        for strategy in ('undersample', 'oversample'):
            # subTest reports which strategy failed and still runs the other.
            with self.subTest(strategy=strategy):
                _, y_out = decoding.stratify(
                    X, y, strategy=strategy, random_state=0
                )
                np.testing.assert_array_equal(np.bincount(y_out), [10, 10])

    def test_oversample_no_replace_when_equal(self):
        """Oversampling a balanced dataset does not duplicate any samples."""
        X, y = self._make_balanced()
        X_out, y_out = decoding.stratify(X, y, strategy='oversample', random_state=0)
        np.testing.assert_array_equal(np.bincount(y_out), [10, 10])
        # The input rows are distinct random draws, so any repeated output row
        # would mean a sample was drawn more than once.
        unique_rows = np.unique(np.asarray(X_out), axis=0)
        self.assertEqual(len(unique_rows), len(X_out))

    def test_works_with_dataframe(self):
        """stratify accepts a pandas DataFrame and returns correct row counts."""
        try:
            import pandas as pd
        except ImportError:
            # Report a skip instead of an error when pandas is unavailable.
            self.skipTest("pandas is not installed")
        n_samples = 40
        X = pd.DataFrame(np.eye(n_samples))
        y = np.array([0] * 30 + [1] * 10)
        X_out, y_out = decoding.stratify(X, y, strategy='undersample', random_state=0)
        self.assertEqual(len(X_out), 20)
        self.assertEqual(len(y_out), 20)

    def test_verbose_no_crash(self):
        """verbose=True runs without error and its stdout output contains '10'."""
        import io
        from contextlib import redirect_stdout
        X, y = self._make_imbalanced()
        buf = io.StringIO()
        with redirect_stdout(buf):
            decoding.stratify(X, y, strategy='undersample', random_state=0, verbose=True)
        # 10 is the minority-class count of the fixture.
        self.assertIn('10', buf.getvalue())
134+
135+
# Run this module's tests with verbose output when it is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments