Member:
I feel like this should be in utils/numba and have its own tests
aeon/transformations/collection/shapelet_based/_quality_measures.py (new file)
@@ -0,0 +1,70 @@
"""Quality measures for shapelet evaluation.

This module contains numba-optimized quality measures for evaluating shapelet
quality in the RandomShapeletTransform.
"""

import numpy as np
from numba import njit


@njit(fastmath=True, cache=True)
def f_statistic(class0_distances, class1_distances):
"""Calculate the F-statistic for shapelet quality.

The F-statistic measures the ratio of between-class variance to within-class
variance. Higher values indicate better class separation.

Parameters
----------
class0_distances : np.ndarray
Array of distances for the first class.
class1_distances : np.ndarray
Array of distances for the second class.

Returns
-------
float
The computed F-statistic. Returns np.inf if either class is empty or
if there are insufficient degrees of freedom.

Notes
-----
The F-statistic is calculated as:
F = (SSB / df_between) / (SSW / df_within)
where SSB is the between-class sum of squares and SSW is the within-class
sum of squares.
"""
if len(class0_distances) == 0 or len(class1_distances) == 0:
return np.inf

# Calculate means
mean_class0 = np.mean(class0_distances)
mean_class1 = np.mean(class1_distances)
all_distances = np.concatenate((class0_distances, class1_distances))
overall_mean = np.mean(all_distances)

n0 = len(class0_distances)
n1 = len(class1_distances)
total_n = n0 + n1

# Between-class sum of squares
ssb = (
n0 * (mean_class0 - overall_mean) ** 2 + n1 * (mean_class1 - overall_mean) ** 2
)

# Within-class sum of squares
ssw = np.sum((class0_distances - mean_class0) ** 2) + np.sum(
(class1_distances - mean_class1) ** 2
)

# Degrees of freedom
df_between = 1
df_within = total_n - 2

# Avoid division by zero
if df_within <= 0:
return np.inf

f_stat = (ssb / df_between) / (ssw / df_within)
return f_stat
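
As a quick illustration of the measure above, a minimal sketch with synthetic distance arrays (the values and variable names are illustrative only):

import numpy as np

from aeon.transformations.collection.shapelet_based._quality_measures import (
    f_statistic,
)

# Distances from one candidate shapelet to the train series of each class.
class0 = np.array([0.10, 0.15, 0.20])  # series close to the shapelet
class1 = np.array([0.90, 0.95, 1.10])  # series far from the shapelet

print(f_statistic(class0, class1))  # well separated -> large F value
print(f_statistic(class0, class0))  # identical groups -> 0.0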
150 changes: 129 additions & 21 deletions aeon/transformations/collection/shapelet_based/_shapelet_transform.py
@@ -18,6 +18,9 @@
from sklearn.utils._random import check_random_state

from aeon.transformations.collection.base import BaseCollectionTransformer
from aeon.transformations.collection.shapelet_based._quality_measures import (
    f_statistic,
)
from aeon.utils.numba.general import AEON_NUMBA_STD_THRESHOLD, z_normalise_series
from aeon.utils.validation import check_n_jobs

@@ -37,13 +40,14 @@ class RandomShapeletTransform(BaseCollectionTransformer):
For each candidate shapelet:
- Extract a shapelet from an instance with random length, position and
dimension and find its distance to each train case.
- Calculate the shapelet's information gain using the ordered list of
distances and train data class labels.
- Calculate the shapelet's quality using the selected quality measure
(information gain by default, or F-statistic) based on distances and
train data class labels.
- Abandon evaluating the shapelet if it is impossible to obtain a higher
information gain than the current worst.
quality than the current worst.
For each shapelet batch:
- Add each candidate to its class's shapelet heap, removing the lowest
information gain shapelet if the max number of shapelets has been met.
quality shapelet if the max number of shapelets has been met.
- Remove self-similar shapelets from the heap.
Using the final set of filtered shapelets, transform the data into a vector
of distances from a series to each shapelet.
@@ -52,7 +56,7 @@
----------
n_shapelet_samples : int, default=10000
The number of candidate shapelets to be evaluated. Filtered down to
<= max_shapelets, keeping the shapelets with the most information gain.
<= max_shapelets, keeping the shapelets with the highest quality score.
max_shapelets : int or None, default=None
Max number of shapelets to keep for the final transform. Each class value will
have its own max, set to n_classes / max_shapelets. If None uses the min between
@@ -84,6 +88,10 @@ class RandomShapeletTransform(BaseCollectionTransformer):
documentation for more details.
random_state : int or None, default=None
Seed for random number generation.
quality_measure : str, default="information_gain"
The quality measure to use for evaluating shapelets. Options are:
- "information_gain": Information gain (default, best accuracy but slower)
- "f_statistic": F-statistic (faster but may have lower accuracy)

Attributes
----------
@@ -101,10 +109,10 @@
The stored shapelets and related information after a dataset has been
processed.
Each item in the list is a tuple containing the following 7 items:
(shapelet information gain, shapelet length, start position the shapelet was
extracted from, shapelet dimension, index of the instance the shapelet was
extracted from in fit, class value of the shapelet, The z-normalised shapelet
array)
(shapelet quality measure value, shapelet length, start position the
shapelet was extracted from, shapelet dimension, index of the instance
the shapelet was extracted from in fit, class value of the shapelet,
the z-normalised shapelet array)

See Also
--------
@@ -164,6 +172,7 @@ def __init__(
        n_jobs: int = 1,
        parallel_backend=None,
        random_state: int | None = None,
        quality_measure: str = "information_gain",
Member:
could you put this above batch_size, remember to do the docstring entry as well
    ) -> None:
        self.n_shapelet_samples = n_shapelet_samples
        self.max_shapelets = max_shapelets
@@ -180,6 +189,15 @@
        self.parallel_backend = parallel_backend
        self.random_state = random_state

        # Validate quality_measure
        valid_measures = ["information_gain", "f_statistic"]
        if quality_measure not in valid_measures:
            raise ValueError(
                f"quality_measure must be one of {valid_measures}, "
                f"got {quality_measure}"
            )
        self.quality_measure = quality_measure

Member (on lines +192 to +200):
validation should be done in _fit or a function called from it
        # The following set in method fit
        self.n_classes_ = 0
        self.n_cases_ = 0
@@ -493,18 +511,34 @@ def _extract_random_shapelet(self, X, y, i, rng):
            dtype=np.int32,
        )

        quality, distances = self._find_shapelet_quality(
            X,
            y,
            shapelet,
            sorted_indicies,
            position,
            length,
            channel,
            inst_idx,
            self._class_counts[cls_idx],
            self.n_cases_ - self._class_counts[cls_idx],
        )
        if self.quality_measure == "information_gain":
            quality, distances = self._find_shapelet_quality(
                X,
                y,
                shapelet,
                sorted_indicies,
                position,
                length,
                channel,
                inst_idx,
                self._class_counts[cls_idx],
                self.n_cases_ - self._class_counts[cls_idx],
            )
        elif self.quality_measure == "f_statistic":
            quality, distances = self._find_shapelet_quality_f_stat(
                X,
                y,
                shapelet,
                sorted_indicies,
                position,
                length,
                channel,
                inst_idx,
                self._class_counts[cls_idx],
                self.n_cases_ - self._class_counts[cls_idx],
            )
        else:
            raise ValueError(f"Unknown quality measure: {self.quality_measure}")

        return (
            List(
@@ -553,6 +587,80 @@ def _find_shapelet_quality(
        orderline.sort()
        return _calc_binary_ig(orderline, this_cls_count, other_cls_count), distances

    @staticmethod
    @njit(fastmath=True, cache=True)
    def _find_shapelet_quality_f_stat(
        X,
Member (on lines +590 to +593):
is having a new function for each new quality measure necessary? Feel like this could be done better
        y,
        shapelet,
        sorted_indicies,
        position,
        length,
        dim,
        inst_idx,
        this_cls_count,
        other_cls_count,
    ):
"""Find shapelet quality using F-statistic.

Parameters
----------
X : list of np.ndarray
The training input samples.
y : np.ndarray
The class labels.
shapelet : np.ndarray
The shapelet to evaluate.
sorted_indicies : np.ndarray
Sorted indices for early abandon optimization.
position : int
Position of the shapelet in the original series.
length : int
Length of the shapelet.
dim : int
Dimension/channel index.
inst_idx : int
Index of the instance the shapelet was extracted from.
this_cls_count : int
Number of cases in the same class as the shapelet.
other_cls_count : int
Number of cases in other classes.

Returns
-------
quality : float
The F-statistic quality measure.
distances : np.ndarray
Array of distances from each series to the shapelet.
"""
distances = np.zeros(len(X))
distances1 = np.zeros(this_cls_count - 1) # -1 because we exclude inst_idx
distances2 = np.zeros(other_cls_count)
c1 = 0
c2 = 0

for i, series in enumerate(X):
if i != inst_idx:
distance = _online_shapelet_distance(
series[dim], shapelet, sorted_indicies, position, length
)
distances[i] = distance

if y[i] == y[inst_idx]:
if c1 < len(distances1):
distances1[c1] = distance
c1 += 1
else:
if c2 < len(distances2):
distances2[c2] = distance
c2 += 1
else:
distances[i] = 0.0

# Calculate F-statistic
quality = f_statistic(distances1[:c1], distances2[:c2])
return quality, distances

    @staticmethod
    @njit(fastmath=True, cache=True)
    def _merge_shapelets(
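
Finally, an end-to-end sketch exercising the branch under review on synthetic data (shapes and parameter values are arbitrary):

import numpy as np

from aeon.transformations.collection.shapelet_based import RandomShapeletTransform

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 1, 50))  # 20 univariate series of length 50
y = np.array([0] * 10 + [1] * 10)

rst = RandomShapeletTransform(
    n_shapelet_samples=100,
    max_shapelets=10,
    quality_measure="f_statistic",
    random_state=0,
)
X_t = rst.fit_transform(X, y)
print(X_t.shape)  # (20, number of retained shapelets)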