-
Notifications
You must be signed in to change notification settings - Fork 268
[ENH] Add flexible quality measures to RandomShapeletTransform #3244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| """Quality measures for shapelet evaluation. | ||
|
|
||
| This module contains numba-optimized quality measures for evaluating shapelet | ||
| quality in the RandomShapeletTransform. | ||
| """ | ||
|
|
||
| import numpy as np | ||
| from numba import njit | ||
|
|
||
|
|
||
| @njit(fastmath=True, cache=True) | ||
| def f_statistic(class0_distances, class1_distances): | ||
| """Calculate the F-statistic for shapelet quality. | ||
|
|
||
| The F-statistic measures the ratio of between-class variance to within-class | ||
| variance. Higher values indicate better class separation. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| class0_distances : np.ndarray | ||
| Array of distances for the first class. | ||
| class1_distances : np.ndarray | ||
| Array of distances for the second class. | ||
|
|
||
| Returns | ||
| ------- | ||
| float | ||
| The computed F-statistic. Returns np.inf if either class is empty or | ||
| if there are insufficient degrees of freedom. | ||
|
|
||
| Notes | ||
| ----- | ||
| The F-statistic is calculated as: | ||
| F = (SSB / df_between) / (SSW / df_within) | ||
| where SSB is the between-class sum of squares and SSW is the within-class | ||
| sum of squares. | ||
| """ | ||
| if len(class0_distances) == 0 or len(class1_distances) == 0: | ||
| return np.inf | ||
|
|
||
| # Calculate means | ||
| mean_class0 = np.mean(class0_distances) | ||
| mean_class1 = np.mean(class1_distances) | ||
| all_distances = np.concatenate((class0_distances, class1_distances)) | ||
| overall_mean = np.mean(all_distances) | ||
|
|
||
| n0 = len(class0_distances) | ||
| n1 = len(class1_distances) | ||
| total_n = n0 + n1 | ||
|
|
||
| # Between-class sum of squares | ||
| ssb = ( | ||
| n0 * (mean_class0 - overall_mean) ** 2 + n1 * (mean_class1 - overall_mean) ** 2 | ||
| ) | ||
|
|
||
| # Within-class sum of squares | ||
| ssw = np.sum((class0_distances - mean_class0) ** 2) + np.sum( | ||
| (class1_distances - mean_class1) ** 2 | ||
| ) | ||
|
|
||
| # Degrees of freedom | ||
| df_between = 1 | ||
| df_within = total_n - 2 | ||
|
|
||
| # Avoid division by zero | ||
| if df_within <= 0: | ||
| return np.inf | ||
|
|
||
| f_stat = (ssb / df_between) / (ssw / df_within) | ||
| return f_stat |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,9 @@ | |
| from sklearn.utils._random import check_random_state | ||
|
|
||
| from aeon.transformations.collection.base import BaseCollectionTransformer | ||
| from aeon.transformations.collection.shapelet_based._quality_measures import ( | ||
| f_statistic, | ||
| ) | ||
| from aeon.utils.numba.general import AEON_NUMBA_STD_THRESHOLD, z_normalise_series | ||
| from aeon.utils.validation import check_n_jobs | ||
|
|
||
|
|
@@ -37,13 +40,14 @@ class RandomShapeletTransform(BaseCollectionTransformer): | |
| For each candidate shapelet: | ||
| - Extract a shapelet from an instance with random length, position and | ||
| dimension and find its distance to each train case. | ||
| - Calculate the shapelet's information gain using the ordered list of | ||
| distances and train data class labels. | ||
| - Calculate the shapelet's quality using the selected quality measure | ||
| (information gain by default, or F-statistic) based on distances and | ||
| train data class labels. | ||
| - Abandon evaluating the shapelet if it is impossible to obtain a higher | ||
| information gain than the current worst. | ||
| quality than the current worst. | ||
| For each shapelet batch: | ||
| - Add each candidate to its classes shapelet heap, removing the lowest | ||
| information gain shapelet if the max number of shapelets has been met. | ||
| quality shapelet if the max number of shapelets has been met. | ||
| - Remove self-similar shapelets from the heap. | ||
| Using the final set of filtered shapelets, transform the data into a vector of | ||
| of distances from a series to each shapelet. | ||
|
|
@@ -52,7 +56,7 @@ class RandomShapeletTransform(BaseCollectionTransformer): | |
| ---------- | ||
| n_shapelet_samples : int, default=10000 | ||
| The number of candidate shapelets to be evaluated. Filtered down to | ||
| <= max_shapelets, keeping the shapelets with the most information gain. | ||
| <= max_shapelets, keeping the shapelets with the highest quality measure. | ||
| max_shapelets : int or None, default=None | ||
| Max number of shapelets to keep for the final transform. Each class value will | ||
| have its own max, set to n_classes / max_shapelets. If None uses the min between | ||
|
|
@@ -84,6 +88,10 @@ class RandomShapeletTransform(BaseCollectionTransformer): | |
| documentation for more details. | ||
| random_state : int or None, default=None | ||
| Seed for random number generation. | ||
| quality_measure : str, default="information_gain" | ||
| The quality measure to use for evaluating shapelets. Options are: | ||
| - "information_gain": Information gain (default, best accuracy but slower) | ||
| - "f_statistic": F-statistic (faster but may have lower accuracy) | ||
|
|
||
| Attributes | ||
| ---------- | ||
|
|
@@ -101,10 +109,10 @@ class RandomShapeletTransform(BaseCollectionTransformer): | |
| The stored shapelets and relating information after a dataset has been | ||
| processed. | ||
| Each item in the list is a tuple containing the following 7 items: | ||
| (shapelet information gain, shapelet length, start position the shapelet was | ||
| extracted from, shapelet dimension, index of the instance the shapelet was | ||
| extracted from in fit, class value of the shapelet, The z-normalised shapelet | ||
| array) | ||
| (shapelet quality measure value, shapelet length, start position the | ||
| shapelet was extracted from, shapelet dimension, index of the instance | ||
| the shapelet was extracted from in fit, class value of the shapelet, | ||
| The z-normalised shapelet array) | ||
|
|
||
| See Also | ||
| -------- | ||
|
|
@@ -164,6 +172,7 @@ def __init__( | |
| n_jobs: int = 1, | ||
| parallel_backend=None, | ||
| random_state: int | None = None, | ||
| quality_measure: str = "information_gain", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you put this above |
||
| ) -> None: | ||
| self.n_shapelet_samples = n_shapelet_samples | ||
| self.max_shapelets = max_shapelets | ||
|
|
@@ -180,6 +189,15 @@ def __init__( | |
| self.parallel_backend = parallel_backend | ||
| self.random_state = random_state | ||
|
|
||
| # Validate quality_measure | ||
| valid_measures = ["information_gain", "f_statistic"] | ||
| if quality_measure not in valid_measures: | ||
| raise ValueError( | ||
| f"quality_measure must be one of {valid_measures}, " | ||
| f"got {quality_measure}" | ||
| ) | ||
| self.quality_measure = quality_measure | ||
|
|
||
|
Comment on lines
+192
to
+200
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. validation should be done in |
||
| # The following set in method fit | ||
| self.n_classes_ = 0 | ||
| self.n_cases_ = 0 | ||
|
|
@@ -493,18 +511,34 @@ def _extract_random_shapelet(self, X, y, i, rng): | |
| dtype=np.int32, | ||
| ) | ||
|
|
||
| quality, distances = self._find_shapelet_quality( | ||
| X, | ||
| y, | ||
| shapelet, | ||
| sorted_indicies, | ||
| position, | ||
| length, | ||
| channel, | ||
| inst_idx, | ||
| self._class_counts[cls_idx], | ||
| self.n_cases_ - self._class_counts[cls_idx], | ||
| ) | ||
| if self.quality_measure == "information_gain": | ||
| quality, distances = self._find_shapelet_quality( | ||
| X, | ||
| y, | ||
| shapelet, | ||
| sorted_indicies, | ||
| position, | ||
| length, | ||
| channel, | ||
| inst_idx, | ||
| self._class_counts[cls_idx], | ||
| self.n_cases_ - self._class_counts[cls_idx], | ||
| ) | ||
| elif self.quality_measure == "f_statistic": | ||
| quality, distances = self._find_shapelet_quality_f_stat( | ||
| X, | ||
| y, | ||
| shapelet, | ||
| sorted_indicies, | ||
| position, | ||
| length, | ||
| channel, | ||
| inst_idx, | ||
| self._class_counts[cls_idx], | ||
| self.n_cases_ - self._class_counts[cls_idx], | ||
| ) | ||
| else: | ||
| raise ValueError(f"Unknown quality measure: {self.quality_measure}") | ||
|
|
||
| return ( | ||
| List( | ||
|
|
@@ -553,6 +587,80 @@ def _find_shapelet_quality( | |
| orderline.sort() | ||
| return _calc_binary_ig(orderline, this_cls_count, other_cls_count), distances | ||
|
|
||
| @staticmethod | ||
| @njit(fastmath=True, cache=True) | ||
| def _find_shapelet_quality_f_stat( | ||
| X, | ||
|
Comment on lines
+590
to
+593
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is having a new function for each new quality measure necessary? Feel like this could be done better |
||
| y, | ||
| shapelet, | ||
| sorted_indicies, | ||
| position, | ||
| length, | ||
| dim, | ||
| inst_idx, | ||
| this_cls_count, | ||
| other_cls_count, | ||
| ): | ||
| """Find shapelet quality using F-statistic. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : list of np.ndarray | ||
| The training input samples. | ||
| y : np.ndarray | ||
| The class labels. | ||
| shapelet : np.ndarray | ||
| The shapelet to evaluate. | ||
| sorted_indicies : np.ndarray | ||
| Sorted indices for early abandon optimization. | ||
| position : int | ||
| Position of the shapelet in the original series. | ||
| length : int | ||
| Length of the shapelet. | ||
| dim : int | ||
| Dimension/channel index. | ||
| inst_idx : int | ||
| Index of the instance the shapelet was extracted from. | ||
| this_cls_count : int | ||
| Number of cases in the same class as the shapelet. | ||
| other_cls_count : int | ||
| Number of cases in other classes. | ||
|
|
||
| Returns | ||
| ------- | ||
| quality : float | ||
| The F-statistic quality measure. | ||
| distances : np.ndarray | ||
| Array of distances from each series to the shapelet. | ||
| """ | ||
| distances = np.zeros(len(X)) | ||
| distances1 = np.zeros(this_cls_count - 1) # -1 because we exclude inst_idx | ||
| distances2 = np.zeros(other_cls_count) | ||
| c1 = 0 | ||
| c2 = 0 | ||
|
|
||
| for i, series in enumerate(X): | ||
| if i != inst_idx: | ||
| distance = _online_shapelet_distance( | ||
| series[dim], shapelet, sorted_indicies, position, length | ||
| ) | ||
| distances[i] = distance | ||
|
|
||
| if y[i] == y[inst_idx]: | ||
| if c1 < len(distances1): | ||
| distances1[c1] = distance | ||
| c1 += 1 | ||
| else: | ||
| if c2 < len(distances2): | ||
| distances2[c2] = distance | ||
| c2 += 1 | ||
| else: | ||
| distances[i] = 0.0 | ||
|
|
||
| # Calculate F-statistic | ||
| quality = f_statistic(distances1[:c1], distances2[:c2]) | ||
| return quality, distances | ||
|
|
||
| @staticmethod | ||
| @njit(fastmath=True, cache=True) | ||
| def _merge_shapelets( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like this should be in
utils/numbaand have its own tests