Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/time_stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- Register reusable flag systems (`BitwiseFlag` or `CategoricalFlag`).
- Initialise flag columns linked to data columns.
- Add/remove flags with Polars expressions.
- Filter the TimeFrame to rows that have (or lack) given flag values.

4. **Data operations**:
- Aggregation: run `AggregationFunction` pipelines with support for missing-data criteria and time anchoring.
Expand Down Expand Up @@ -663,6 +664,38 @@ def encode_flag_column(self, flag_column_name: str) -> TimeFrame:
tf._flag_manager.flag_columns[flag_column_name].is_decoded = False
return tf

def filter_by_flag(
    self,
    flag_column_name: str,
    flag: int | str | list[int | str],
    include: bool = True,
) -> TimeFrame:
    """Return a copy of this TimeFrame keeping only rows that have (or lack) the given flags.

    Matching uses any-of semantics: a row matches when at least one of the requested flags is present in the
    flag column. The flag column itself builds the match expression, so both encoded and decoded columns are
    supported.

    Args:
        flag_column_name: The name of the registered flag column to filter on.
        flag: A single flag name or value, or a list of them.
        include: If ``True`` (default), keep only the rows that match. If ``False``, keep only the rows that do
            not match; rows whose flag value is null count as "not matching" and are therefore kept.

    Returns:
        A new TimeFrame containing only the rows that satisfy the filter condition.

    Raises:
        ColumnNotFoundError: If ``flag_column_name`` is not a registered flag column.
        BitwiseFlagUnknownError: If a requested flag is not in a bitwise flag system.
        CategoricalFlagUnknownError: If a requested flag is not in a categorical flag system.
    """
    # Normalise the scalar form to a one-element list so the flag column always receives a list
    requested = [flag] if not isinstance(flag, list) else flag
    predicate = self.get_flag_column(flag_column_name).filter_expr(requested)

    if not include:
        # A null match result means "no flag information"; treat it as False before negating so that
        # those rows survive the exclusion rather than being dropped as null by the filter
        predicate = ~predicate.fill_null(False)

    filtered = self.copy()
    filtered._df = self.df.filter(predicate)
    filtered._column_metadata.sync()
    return filtered

def aggregate(
self,
aggregation_period: Period | str,
Expand Down
78 changes: 78 additions & 0 deletions src/time_stream/flags/flag_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ def remove_flag(self, df: pl.DataFrame, flag: int | str, expr: pl.Expr = pl.lit(
"""
raise NotImplementedError

@abstractmethod
def filter_expr(self, flags: list[int | str]) -> pl.Expr:
    """Build the row-matching expression for a set of flags.

    Implementations return a boolean Polars expression that evaluates to True for every row in which at least
    one of the requested flags is set, and are responsible for coping with the column being in either its
    encoded or its decoded state.

    Args:
        flags: Flag names or values to match against (any-of semantics).

    Returns:
        A boolean Polars expression.
    """
    raise NotImplementedError


@dataclass
class BitwiseFlagColumn(FlagColumn):
Expand Down Expand Up @@ -223,6 +237,33 @@ def remove_flag(self, df: pl.DataFrame, flag: int | str, expr: pl.Expr = pl.lit(
df = self.decode(df)
return df

def filter_expr(self, flags: list[int | str]) -> pl.Expr:
    """Return a boolean expression that is True for rows where any of the given flags are set.

    For an encoded column, the requested flags are OR-ed into a single bit mask and each row is tested with a
    bitwise AND against that mask. For a decoded column, each row's list of flag names is checked for any of
    the requested names.

    Args:
        flags: Flag names or bit values to match against. An empty list matches no rows.

    Returns:
        A boolean Polars expression.

    Raises:
        BitwiseFlagUnknownError: If any flag is not in the flag system.
    """
    # Resolve every name/value to its flag enum member first, so unknown flags fail before any expression is built
    flag_members = [self.flag_system.get_flag(f) for f in flags]

    if self.is_decoded:
        if not flag_members:
            # pl.any_horizontal cannot be built from zero expressions. With nothing requested, nothing can
            # match - this mirrors the encoded branch, where an empty mask (0) never yields True.
            return pl.lit(False)
        column = pl.col(self.name)
        return pl.any_horizontal([column.list.contains(pl.lit(member.name)) for member in flag_members])

    # Encoded: fold the requested flags into one mask. OR (not sum) so duplicate flags don't corrupt the mask.
    mask = 0
    for member in flag_members:
        mask |= int(member)
    return (pl.col(self.name) & pl.lit(mask)) != 0

def __eq__(self, other: object) -> bool:
"""Check if two ``BitwiseFlagColumn`` instances are equal.

Expand Down Expand Up @@ -371,6 +412,22 @@ def remove_flag(self, df: pl.DataFrame, flag: int | str, expr: pl.Expr = pl.lit(
df = self.decode(df)
return df

def filter_expr(self, flags: list[int | str]) -> pl.Expr:
    """Build a boolean expression that is True for rows whose value equals any of the given flags.

    The comparison is made against flag names when the column is decoded, and against flag values when it is
    encoded.

    Args:
        flags: Flag names or values to match against (any-of semantics).

    Returns:
        A boolean Polars expression.

    Raises:
        CategoricalFlagUnknownError: If any flag is not in the flag system.
    """
    # Resolve all inputs to flag enum members first so that unknown flags are rejected up front
    members = [self.flag_system.get_flag(requested) for requested in flags]

    if self.is_decoded:
        targets = [member.name for member in members]
    else:
        targets = [member.value for member in members]

    column = pl.col(self.name)
    return column.is_in(targets)

def __eq__(self, other: object) -> bool:
"""Check if two ``CategoricalSingleFlagColumn`` instances are equal.

Expand Down Expand Up @@ -519,6 +576,27 @@ def remove_flag(self, df: pl.DataFrame, flag: int | str, expr: pl.Expr = pl.lit(
df = self.decode(df)
return df

def filter_expr(self, flags: list[int | str]) -> pl.Expr:
    """Return a boolean expression that is True for rows whose flag list contains any of the given flags.

    Each row holds a list of flags. The list is searched for flag names when the column is decoded, and for
    flag values when it is encoded.

    Args:
        flags: Flag names or values to match against. An empty list matches no rows.

    Returns:
        A boolean Polars expression.

    Raises:
        CategoricalFlagUnknownError: If any flag is not in the flag system.
    """
    # Resolve every name/value to its flag enum member first, so unknown flags fail before any expression is built
    flag_members = [self.flag_system.get_flag(f) for f in flags]
    if not flag_members:
        # pl.any_horizontal cannot be built from zero expressions. With nothing requested, nothing can match.
        return pl.lit(False)

    values = [member.name if self.is_decoded else member.value for member in flag_members]
    column = pl.col(self.name)
    return pl.any_horizontal([column.list.contains(pl.lit(value)) for value in values])

def __eq__(self, other: object) -> bool:
"""Check if two ``CategoricalListFlagColumn`` instances are equal.

Expand Down
199 changes: 198 additions & 1 deletion tests/time_stream/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

from time_stream.aggregation import Percentile
from time_stream.base import TimeFrame
from time_stream.exceptions import ColumnNotFoundError, FlagSystemNotFoundError, MetadataError
from time_stream.exceptions import (
BitwiseFlagUnknownError,
CategoricalFlagUnknownError,
ColumnNotFoundError,
FlagSystemNotFoundError,
MetadataError,
)
from time_stream.flags.flag_manager import BitwiseFlagColumn
from time_stream.flags.flag_system import FlagSystemBase
from time_stream.period import Period
Expand Down Expand Up @@ -829,3 +835,194 @@ def test_calculate_min_max_envelope(self) -> None:
actual_tf = tf.calculate_min_max_envelope()

assert_frame_equal(expected_df, actual_tf.df, check_column_order=False)


class TestFilterByFlag:
    """Tests for ``TimeFrame.filter_by_flag`` across bitwise, categorical and categorical-list flag columns."""

    @staticmethod
    def _base_tf() -> TimeFrame:
        """Build the shared five-row TimeFrame (daily timestamps, values 10..50) used by every setup helper."""
        days = [datetime(2025, 1, day) for day in range(1, 6)]
        return TimeFrame(pl.DataFrame({"time": days, "value": [10, 20, 30, 40, 50]}), "time")

    @staticmethod
    def _values(tf: TimeFrame) -> list[int]:
        """Return the ``value`` column of a TimeFrame as a plain list, for compact assertions."""
        return tf.df["value"].to_list()

    @staticmethod
    def setup_bitwise_tf() -> TimeFrame:
        """Bitwise flags: FLAG_A on 20 and 40, FLAG_B on 30 and 40, FLAG_C on 50; 10 is unflagged."""
        tf = TestFilterByFlag._base_tf().with_flag_system("qc", {"FLAG_A": 1, "FLAG_B": 2, "FLAG_C": 4})
        tf.init_flag_column("qc", "flag_col")
        rules = (
            ("FLAG_A", pl.col("value").is_in([20, 40])),
            ("FLAG_B", pl.col("value").is_in([30, 40])),
            ("FLAG_C", pl.col("value") == 50),
        )
        for flag_name, condition in rules:
            tf.add_flag("flag_col", flag_name, condition)
        return tf

    @staticmethod
    def setup_categorical_scalar_tf() -> TimeFrame:
        """Categorical (scalar) flags: FLAG_A on 20 and 50, FLAG_B on 30, FLAG_C on 40; 10 is unflagged."""
        tf = TestFilterByFlag._base_tf().with_flag_system(
            "qc", {"FLAG_A": 0, "FLAG_B": 1, "FLAG_C": 2}, flag_type="categorical"
        )
        tf.init_flag_column("qc", "cat_flag")
        rules = (
            ("FLAG_A", pl.col("value").is_in([20, 50])),
            ("FLAG_B", pl.col("value") == 30),
            ("FLAG_C", pl.col("value") == 40),
        )
        for flag_name, condition in rules:
            tf.add_flag("cat_flag", flag_name, condition)
        return tf

    @staticmethod
    def setup_categorical_list_tf() -> TimeFrame:
        """Categorical (list) flags: FLAG_A on 20 and 40, FLAG_B on 30 and 40, FLAG_C on 50; 10 is unflagged."""
        tf = TestFilterByFlag._base_tf().with_flag_system(
            "qc", {"FLAG_A": 0, "FLAG_B": 1, "FLAG_C": 2}, flag_type="categorical_list"
        )
        tf.init_flag_column("qc", "list_flag")
        rules = (
            ("FLAG_A", pl.col("value").is_in([20, 40])),
            ("FLAG_B", pl.col("value").is_in([30, 40])),
            ("FLAG_C", pl.col("value") == 50),
        )
        for flag_name, condition in rules:
            tf.add_flag("list_flag", flag_name, condition)
        return tf

    # ----- Bitwise flag column (encoded) -----

    def test_bitwise_include_single_flag_by_name(self) -> None:
        """Only the rows carrying FLAG_A are kept when filtering by its name."""
        tf = self.setup_bitwise_tf()
        result = tf.filter_by_flag("flag_col", "FLAG_A")
        assert self._values(result) == [20, 40]

    def test_bitwise_include_single_flag_by_int(self) -> None:
        """Filtering by the flag's integer value gives the same rows as filtering by name."""
        tf = self.setup_bitwise_tf()
        result = tf.filter_by_flag("flag_col", 1)
        assert self._values(result) == [20, 40]

    def test_bitwise_exclude_single_flag(self) -> None:
        """With include=False the rows carrying FLAG_A are dropped."""
        tf = self.setup_bitwise_tf()
        result = tf.filter_by_flag("flag_col", "FLAG_A", include=False)
        assert self._values(result) == [10, 30, 50]

    def test_bitwise_include_multiple_flags(self) -> None:
        """A list of flags is any-of: rows with FLAG_A or FLAG_C are kept."""
        tf = self.setup_bitwise_tf()
        result = tf.filter_by_flag("flag_col", ["FLAG_A", "FLAG_C"])
        assert self._values(result) == [20, 40, 50]

    def test_bitwise_combined_value_matched_by_constituent_flag(self) -> None:
        """A row holding FLAG_A|FLAG_B still matches when filtering on FLAG_B alone."""
        tf = self.setup_bitwise_tf()
        result = tf.filter_by_flag("flag_col", "FLAG_B")
        assert self._values(result) == [30, 40]

    def test_bitwise_unknown_flag_raises_error(self) -> None:
        """A flag name missing from the bitwise flag system raises an error."""
        tf = self.setup_bitwise_tf()
        with pytest.raises(BitwiseFlagUnknownError):
            tf.filter_by_flag("flag_col", "FLAG_Z")

    # ----- Bitwise flag column (decoded) -----

    def test_bitwise_decoded_include(self) -> None:
        """Inclusion works on a decoded bitwise flag column."""
        tf = self.setup_bitwise_tf().decode_flag_column("flag_col")
        result = tf.filter_by_flag("flag_col", "FLAG_A")
        assert self._values(result) == [20, 40]

    def test_bitwise_decoded_exclude(self) -> None:
        """Exclusion works on a decoded bitwise flag column."""
        tf = self.setup_bitwise_tf().decode_flag_column("flag_col")
        result = tf.filter_by_flag("flag_col", "FLAG_A", include=False)
        assert self._values(result) == [10, 30, 50]

    def test_bitwise_decoded_multiple_flags(self) -> None:
        """Any-of logic for several flags works on a decoded bitwise flag column."""
        tf = self.setup_bitwise_tf().decode_flag_column("flag_col")
        result = tf.filter_by_flag("flag_col", ["FLAG_A", "FLAG_C"])
        assert self._values(result) == [20, 40, 50]

    # ----- Categorical (scalar) flag column -----

    def test_categorical_scalar_include_by_name(self) -> None:
        """Only the rows whose flag is FLAG_A are kept."""
        tf = self.setup_categorical_scalar_tf()
        result = tf.filter_by_flag("cat_flag", "FLAG_A")
        assert self._values(result) == [20, 50]

    def test_categorical_scalar_include_by_value(self) -> None:
        """Filtering by the flag's integer value selects the rows holding that flag."""
        tf = self.setup_categorical_scalar_tf()
        result = tf.filter_by_flag("cat_flag", 1)
        assert self._values(result) == [30]

    def test_categorical_scalar_exclude(self) -> None:
        """With include=False the rows whose flag is FLAG_C are dropped."""
        tf = self.setup_categorical_scalar_tf()
        result = tf.filter_by_flag("cat_flag", "FLAG_C", include=False)
        assert self._values(result) == [10, 20, 30, 50]

    def test_categorical_scalar_multiple_flags(self) -> None:
        """Rows whose flag is FLAG_A or FLAG_B are kept."""
        tf = self.setup_categorical_scalar_tf()
        result = tf.filter_by_flag("cat_flag", ["FLAG_A", "FLAG_B"])
        assert self._values(result) == [20, 30, 50]

    def test_categorical_scalar_decoded_include(self) -> None:
        """Inclusion works on a decoded categorical scalar column."""
        tf = self.setup_categorical_scalar_tf().decode_flag_column("cat_flag")
        result = tf.filter_by_flag("cat_flag", "FLAG_A")
        assert self._values(result) == [20, 50]

    def test_categorical_scalar_unknown_flag_raises_error(self) -> None:
        """A flag name missing from the categorical flag system raises an error."""
        tf = self.setup_categorical_scalar_tf()
        with pytest.raises(CategoricalFlagUnknownError):
            tf.filter_by_flag("cat_flag", "FLAG_Z")

    # ----- Categorical (list) flag column -----

    def test_categorical_list_include_by_name(self) -> None:
        """Only the rows whose flag list contains FLAG_A are kept."""
        tf = self.setup_categorical_list_tf()
        result = tf.filter_by_flag("list_flag", "FLAG_A")
        assert self._values(result) == [20, 40]

    def test_categorical_list_exclude(self) -> None:
        """With include=False the rows whose flag list contains FLAG_C are dropped."""
        tf = self.setup_categorical_list_tf()
        result = tf.filter_by_flag("list_flag", "FLAG_C", include=False)
        assert self._values(result) == [10, 20, 30, 40]

    def test_categorical_list_multiple_flags(self) -> None:
        """Rows whose flag list contains FLAG_A or FLAG_C are kept."""
        tf = self.setup_categorical_list_tf()
        result = tf.filter_by_flag("list_flag", ["FLAG_A", "FLAG_C"])
        assert self._values(result) == [20, 40, 50]

    def test_categorical_list_decoded_include(self) -> None:
        """Inclusion works on a decoded categorical list column."""
        tf = self.setup_categorical_list_tf().decode_flag_column("list_flag")
        result = tf.filter_by_flag("list_flag", "FLAG_A")
        assert self._values(result) == [20, 40]

    # ----- Error handling common to all column types -----

    def test_invalid_column_name_raises_error(self) -> None:
        """A column name that is not a registered flag column raises an error."""
        tf = self.setup_bitwise_tf()
        with pytest.raises(ColumnNotFoundError):
            tf.filter_by_flag("nonexistent_col", "FLAG_A")