Skip to content

Commit bcfcd92

Browse files
committed
Esmote implement for aeon
1 parent bf628f4 commit bcfcd92

3 files changed

Lines changed: 343 additions & 1 deletion

File tree

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Supervised transformers to rebalance colelctions of time series."""
22

3-
__all__ = ["ADASYN", "SMOTE", "OHIT"]
3+
__all__ = ["ADASYN", "SMOTE", "OHIT", "ESMOTE"]
44

55
from aeon.transformations.collection.imbalance._adasyn import ADASYN
66
from aeon.transformations.collection.imbalance._ohit import OHIT
77
from aeon.transformations.collection.imbalance._smote import SMOTE
8+
from aeon.transformations.collection.imbalance._esmote import ESMOTE
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
from collections import OrderedDict
2+
from typing import Optional, Union
3+
4+
import numpy as np
5+
from numba import prange
6+
from sklearn.utils import check_random_state
7+
8+
from aeon.classification.distance_based import KNeighborsTimeSeriesClassifier
9+
from aeon.clustering.averaging._ba_utils import _get_alignment_path
10+
from aeon.transformations.collection import BaseCollectionTransformer
11+
12+
__all__ = ["ESMOTE"]
13+
14+
class KNN(KNeighborsTimeSeriesClassifier):
15+
"""
16+
KNN classifier for time series data, adapted to work with ESMOTE.
17+
This class is a wrapper around the original KNeighborsTimeSeriesClassifier
18+
to ensure compatibility with the ESMOTE transformation.
19+
"""
20+
21+
def _fit_setup(self, X, y):
22+
# KNN can support if all labels are the same so always return False for single
23+
# class problem so the fit will always run
24+
X, y, _ = super()._fit_setup(X, y)
25+
return X, y, False
26+
27+
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
28+
"""Find the K-neighbors of a point.
29+
30+
Returns indices of and distances to the neighbors of each point.
31+
32+
Parameters
33+
----------
34+
X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints) or list of
35+
shape [n_cases] of 2D arrays shape (n_channels,n_timepoints_i)
36+
The query point or points.
37+
If not provided, neighbors of each indexed point are returned.
38+
In this case, the query point is not considered its own neighbor.
39+
n_neighbors : int, default=None
40+
Number of neighbors required for each sample. The default is the value
41+
passed to the constructor.
42+
return_distance : bool, default=True
43+
Whether or not to return the distances.
44+
45+
Returns
46+
-------
47+
neigh_dist : ndarray of shape (n_queries, n_neighbors)
48+
Array representing the distances to points, only present if
49+
return_distance=True.
50+
neigh_ind : ndarray of shape (n_queries, n_neighbors)
51+
Indices of the nearest points in the population matrix.
52+
"""
53+
self._check_is_fitted()
54+
import numbers
55+
from aeon.distances import pairwise_distance
56+
if n_neighbors is None:
57+
n_neighbors = self.n_neighbors
58+
elif n_neighbors <= 0:
59+
raise ValueError(f"Expected n_neighbors > 0. Got {n_neighbors}")
60+
elif not isinstance(n_neighbors, numbers.Integral):
61+
raise TypeError(
62+
f"n_neighbors does not take {type(n_neighbors)} value, "
63+
"enter integer value"
64+
)
65+
66+
query_is_train = X is None
67+
if query_is_train:
68+
X = self.X_
69+
n_neighbors += 1
70+
else:
71+
X = self._preprocess_collection(X, store_metadata=False)
72+
self._check_shape(X)
73+
74+
# Compute pairwise distances between X and fit data
75+
distances = pairwise_distance(
76+
X,
77+
self.X_ if not query_is_train else None,
78+
method=self.distance,
79+
**self._distance_params,
80+
)
81+
82+
sample_range = np.arange(distances.shape[0])[:, None]
83+
neigh_ind = np.argpartition(distances, n_neighbors - 1, axis=1)
84+
neigh_ind = neigh_ind[:, :n_neighbors]
85+
neigh_ind = neigh_ind[
86+
sample_range, np.argsort(distances[sample_range, neigh_ind])
87+
]
88+
89+
if query_is_train:
90+
neigh_ind = neigh_ind[:, 1:]
91+
92+
if return_distance:
93+
if query_is_train:
94+
neigh_dist = distances[sample_range, neigh_ind]
95+
return neigh_dist, neigh_ind
96+
return distances[sample_range, neigh_ind], neigh_ind
97+
98+
return neigh_ind
99+
100+
101+
class ESMOTE(BaseCollectionTransformer):
102+
"""
103+
Elastic Synthetic Minority Over-sampling Technique (ESMOTE).
104+
Parameters
105+
----------
106+
n_neighbors : int, default=5
107+
The number of nearest neighbors used to define the neighborhood of samples
108+
to use to generate the synthetic time series.
109+
distance : str or callable, default="msm"
110+
The distance metric to use for the nearest neighbors search and alignment path
111+
of synthetic time series.
112+
weights : str or callable, default = 'uniform'
113+
Mechanism for weighting a vote one of: ``'uniform'``, ``'distance'``,
114+
or a callable
115+
function.
116+
random_state : int, RandomState instance or None, default=None
117+
If `int`, random_state is the seed used by the random number generator;
118+
If `RandomState` instance, random_state is the random number generator;
119+
If `None`, the random number generator is the `RandomState` instance used
120+
by `np.random`.
121+
See Also
122+
--------
123+
ADASYN
124+
References
125+
----------
126+
.. [1] Chawla et al. SMOTE: synthetic minority over-sampling technique, Journal
127+
of Artificial Intelligence Research 16(1): 321–357, 2002.
128+
https://dl.acm.org/doi/10.5555/1622407.1622416
129+
"""
130+
131+
_tags = {
132+
"capability:multivariate": False,
133+
"capability:unequal_length": False,
134+
"requires_y": True,
135+
}
136+
137+
def __init__(
138+
self,
139+
n_neighbors=5,
140+
distance: Union[str, callable] = "msm",
141+
distance_params: Optional[dict] = None,
142+
weights: Union[str, callable] = "uniform",
143+
n_jobs: int = 1,
144+
random_state=None,
145+
):
146+
self.random_state = random_state
147+
self.n_neighbors = n_neighbors
148+
self.distance = distance
149+
self.distance_params = distance_params
150+
self.weights = weights
151+
self.n_jobs = n_jobs
152+
153+
self._random_state = None
154+
self._distance_params = distance_params or {}
155+
156+
self.nn_ = None
157+
super().__init__()
158+
159+
def _fit(self, X, y=None):
160+
self._random_state = check_random_state(self.random_state)
161+
self.nn_ = KNN(
162+
n_neighbors=self.n_neighbors + 1,
163+
distance=self.distance,
164+
distance_params=self._distance_params,
165+
weights=self.weights,
166+
n_jobs=self.n_jobs,
167+
)
168+
169+
# generate sampling target by targeting all classes except the majority
170+
unique, counts = np.unique(y, return_counts=True)
171+
target_stats = dict(zip(unique, counts))
172+
n_sample_majority = max(target_stats.values())
173+
class_majority = max(target_stats, key=target_stats.get)
174+
sampling_strategy = {
175+
key: n_sample_majority - value
176+
for (key, value) in target_stats.items()
177+
if key != class_majority
178+
}
179+
self.sampling_strategy_ = OrderedDict(sorted(sampling_strategy.items()))
180+
return self
181+
182+
def _transform(self, X, y=None):
183+
X_resampled = [X.copy()]
184+
y_resampled = [y.copy()]
185+
186+
# got the minority class label and the number needs to be generated
187+
for class_sample, n_samples in self.sampling_strategy_.items():
188+
if n_samples == 0:
189+
continue
190+
target_class_indices = np.flatnonzero(y == class_sample)
191+
X_class = X[target_class_indices]
192+
y_class = y[target_class_indices]
193+
194+
self.nn_.fit(X_class, y_class)
195+
nns = self.nn_.kneighbors(X_class, return_distance=False)[:, 1:]
196+
X_new, y_new = self._make_samples(
197+
X_class,
198+
y.dtype,
199+
class_sample,
200+
X_class,
201+
nns,
202+
n_samples,
203+
1.0,
204+
n_jobs=self.n_jobs,
205+
)
206+
X_resampled.append(X_new)
207+
y_resampled.append(y_new)
208+
X_synthetic = np.vstack(X_resampled)
209+
y_synthetic = np.hstack(y_resampled)
210+
211+
return X_synthetic, y_synthetic
212+
213+
214+
def _make_samples(
215+
self, X, y_dtype, y_type, nn_data, nn_num, n_samples, step_size=1.0, n_jobs=1
216+
):
217+
samples_indices = self._random_state.randint(
218+
low=0, high=nn_num.size, size=n_samples
219+
)
220+
221+
steps = step_size * self._random_state.uniform(low=0, high=1, size=n_samples)[:, np.newaxis]
222+
rows = np.floor_divide(samples_indices, nn_num.shape[1])
223+
cols = np.mod(samples_indices, nn_num.shape[1])
224+
distance = self.distance
225+
226+
X_new = _generate_samples(
227+
X,
228+
nn_data,
229+
nn_num,
230+
rows,
231+
cols,
232+
steps,
233+
random_state=self._random_state,
234+
distance=distance,
235+
**self._distance_params,
236+
)
237+
y_new = np.full(n_samples, fill_value=y_type, dtype=y_dtype)
238+
return X_new, y_new
239+
240+
241+
def _generate_samples(
242+
X,
243+
nn_data,
244+
nn_num,
245+
rows,
246+
cols,
247+
steps,
248+
random_state,
249+
distance,
250+
weights: Optional[np.ndarray] = None,
251+
window: Union[float, None] = None,
252+
g: float = 0.0,
253+
epsilon: Union[float, None] = None,
254+
nu: float = 0.001,
255+
lmbda: float = 1.0,
256+
independent: bool = True,
257+
c: float = 1.0,
258+
descriptor: str = "identity",
259+
reach: int = 15,
260+
warp_penalty: float = 1.0,
261+
transformation_precomputed: bool = False,
262+
transformed_x: Optional[np.ndarray] = None,
263+
transformed_y: Optional[np.ndarray] = None,
264+
):
265+
X_new = np.zeros((len(rows), *X.shape[1:]), dtype=X.dtype)
266+
267+
for count in prange(len(rows)):
268+
i = rows[count]
269+
j = cols[count]
270+
curr_ts = X[i] # shape: (c, l)
271+
nn_ts = nn_data[nn_num[i, j]] # shape: (c, l)
272+
new_ts = curr_ts.copy()
273+
alignment, _ = _get_alignment_path(
274+
nn_ts,
275+
curr_ts,
276+
distance,
277+
window,
278+
g,
279+
epsilon,
280+
nu,
281+
lmbda,
282+
independent,
283+
c,
284+
descriptor,
285+
reach,
286+
warp_penalty,
287+
transformation_precomputed,
288+
transformed_x,
289+
transformed_y,
290+
)
291+
path_list = [[] for _ in range(curr_ts.shape[1])]
292+
for k, l in alignment:
293+
path_list[k].append(l)
294+
295+
# num_of_alignments = np.zeros_like(curr_ts, dtype=np.int32)
296+
empty_of_array = np.zeros_like(curr_ts, dtype=float) # shape: (c, l)
297+
298+
for k, l in enumerate(path_list):
299+
if len(l) == 0:
300+
print("No alignment found for time step")
301+
return new_ts
302+
303+
key = random_state.choice(l)
304+
# Compute difference for all channels at this time step
305+
empty_of_array[:, k] = curr_ts[:, k] - nn_ts[:, key]
306+
307+
X_new[count] = new_ts + steps[count] * empty_of_array #/ num_of_alignments
308+
309+
return X_new
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Test function for ESMOTE."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
from aeon.testing.data_generation import make_example_3d_numpy
7+
from aeon.transformations.collection.imbalance import ESMOTE
8+
9+
10+
def test_smote():
11+
"""Test the ESMOTE class.
12+
13+
This function creates a 3D numpy array, applies
14+
ESMOTE using the ESMOTE class, and asserts that the
15+
transformed data has a balanced number of samples.
16+
"""
17+
n_samples = 100 # Total number of labels
18+
majority_num = 90 # number of majority class
19+
minority_num = n_samples - majority_num # number of minority class
20+
21+
X = np.random.rand(n_samples, 1, 10)
22+
y = np.array([0] * majority_num + [1] * minority_num)
23+
24+
transformer = ESMOTE()
25+
transformer.fit(X, y)
26+
res_X, res_y = transformer.transform(X, y)
27+
_, res_count = np.unique(res_y, return_counts=True)
28+
29+
assert len(res_X) == 2 * majority_num
30+
assert len(res_y) == 2 * majority_num
31+
assert res_count[0] == majority_num
32+
assert res_count[1] == majority_num

0 commit comments

Comments
 (0)