Skip to content

Commit ba8a74d

Browse files
authored
Merge pull request #69 from NERC-CEH/feature/FPM-586-remove-base-flag-column
2 parents cb4a433 + c1d7243 commit ba8a74d

5 files changed

Lines changed: 100 additions & 120 deletions

File tree

src/time_stream/base.py

Lines changed: 20 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -438,34 +438,29 @@ def get_flag_system(self, name: str) -> type[BitwiseFlag]:
438438
"""
439439
return self._flag_manager.get_flag_system(name)
440440

441-
def register_flag_column(self, column_name: str, base: str, flag_system: str) -> None:
441+
def register_flag_column(self, column_name: str, flag_system: str) -> None:
442442
"""Mark the specified existing column as a flag column.
443443
444-
This does not modify the DataFrame; it only records that ``name`` is a flag column associated with the
445-
value column ``base``, with values handled by the flag system ``flag_system``.
444+
This does not modify the DataFrame; it only records that ``column_name`` is a flag column
445+
with values handled by the flag system ``flag_system``.
446446
447447
Args:
448448
column_name: A column name to mark as a flag column.
449-
base: Name of the value/data column this flag column refers to.
450449
flag_system: The name of the flag system.
451450
"""
452-
check_columns_in_dataframe(self.df, [column_name, base])
453-
self._flag_manager.register_flag_column(column_name, base, flag_system)
451+
check_columns_in_dataframe(self.df, column_name)
452+
self._flag_manager.register_flag_column(column_name, flag_system)
454453

455-
def init_flag_column(
456-
self, base: str, flag_system: str, column_name: str | None = None, data: int | Sequence[int] = 0
457-
) -> None:
454+
def init_flag_column(self, flag_system: str, column_name: str | None = None, data: int | Sequence[int] = 0) -> None:
458455
"""Add a new column to the TimeFrame DataFrame, setting it as a Flag Column.
459456
460457
Args:
461-
base: Name of the value/data column this flag column will refer to.
462458
flag_system: The name of the flag system.
463459
column_name: Optional name for the new flag column. If omitted, a name of the
464-
form "{base}__flag__{flag_system}" is used.
460+
form ``__flag__{flag_system}`` is used, with an integer suffix appended if the name
461+
already exists (e.g. ``__flag__CORE_FLAGS__1``).
465462
data: The default value to populate the flag column with. Can be a scalar or list-like. Defaults to 0.
466463
"""
467-
check_columns_in_dataframe(self.df, base)
468-
469464
# Validate that the flag system exists
470465
self.get_flag_system(flag_system)
471466

@@ -477,11 +472,16 @@ def init_flag_column(
477472

478473
# Determine name of flag column
479474
if not column_name:
480-
column_name = f"{base}__flag__{flag_system}"
475+
column_name = f"__flag__{flag_system}"
476+
if column_name in self.df.columns:
477+
col_suffix = 1
478+
while f"{column_name}__{col_suffix}" in self.df.columns:
479+
col_suffix += 1
480+
column_name = f"{column_name}__{col_suffix}"
481481

482482
# Add and register as a flag column
483483
self._df = self.df.with_columns(data.alias(column_name))
484-
self.register_flag_column(column_name, base, flag_system)
484+
self.register_flag_column(column_name, flag_system)
485485
self._column_metadata.sync()
486486

487487
def get_flag_column(self, flag_column_name: str) -> FlagColumn:
@@ -769,19 +769,18 @@ def infill(
769769
def select(
770770
self,
771771
column_names: str | Sequence[str],
772-
include_flag_columns: bool = True,
773772
) -> Self:
774773
"""Return a new TimeFrame instance to include only the specified columns.
775774
776-
By default, this:
775+
This:
777776
- carries over TimeFrame-level metadata,
778777
- prunes column-level metadata to the kept columns,
779778
- rebuilds the flag manager to include only kept flag columns.
780779
780+
Flag columns are not automatically included; name them explicitly if you want them retained.
781+
781782
Args:
782783
column_names: Column name(s) to retain in the updated TimeFrame.
783-
include_flag_columns: If True, include any registered flag columns whose base is among the
784-
kept value columns.
785784
786785
Returns:
787786
New TimeFrame instance with only selected columns.
@@ -797,13 +796,6 @@ def select(
797796
if self.time_name not in column_names:
798797
column_names.insert(0, self.time_name)
799798

800-
# Optionally include associated flag columns
801-
if include_flag_columns:
802-
for flag_name, flag_column in self._flag_manager.flag_columns.items():
803-
# include if its base (value col) is being kept
804-
if flag_column.base in column_names:
805-
column_names.append(flag_name)
806-
807799
# Build new frame
808800
new_df = self.df.select(column_names)
809801

@@ -824,9 +816,7 @@ def select(
824816
# keep only flag columns that survived
825817
for flag_name, flag_column in self._flag_manager.flag_columns.items():
826818
if flag_name in column_names:
827-
new_flag_manager.register_flag_column(
828-
flag_name, flag_column.base, flag_column.flag_system.system_name()
829-
)
819+
new_flag_manager.register_flag_column(flag_name, flag_column.flag_system.system_name())
830820

831821
tf._flag_manager = new_flag_manager
832822
tf._column_metadata.sync()
@@ -868,7 +858,7 @@ def __getitem__(self, key: str | Sequence[str]) -> Self:
868858
"""
869859
if isinstance(key, str):
870860
key = [key]
871-
return self.select(key, include_flag_columns=False)
861+
return self.select(key)
872862

873863
def __str__(self) -> str:
874864
"""Return the string representation of the TimeFrame dataframe."""

src/time_stream/flag_manager.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
77
Typical use from within the TimeFrame class:
88
1) Register a flag system by name (from a dict or an existing BitwiseFlag subclass).
9-
2) Register a DataFrame column as a flag column associated with that system.
9+
2) Register a DataFrame column as a flag column linked to that system.
1010
3) Use FlagColumn.add_flag / remove_flag with Polars expressions to set/clear the flag bits.
1111
"""
1212

@@ -31,17 +31,14 @@
3131
class FlagColumn:
3232
"""Represents a flag column in a TimeFrame.
3333
34-
A flag column stores bitwise flags governed by a specific flag system. Each flag column is associated with a base
35-
data column.
34+
A flag column stores bitwise flags governed by a specific flag system.
3635
3736
Attributes:
3837
name: Name of the flag column in the DataFrame.
39-
base: Name of the associated value/data column.
4038
flag_system: The enum class that defines the available flags and their bit values.
4139
"""
4240

4341
name: str
44-
base: str
4542
flag_system: type[BitwiseFlag]
4643

4744
def add_flag(self, df: pl.DataFrame, flag: int | str, expr: pl.Expr = pl.lit(True)) -> pl.DataFrame:
@@ -90,7 +87,7 @@ def __eq__(self, other: object) -> bool:
9087
if not isinstance(other, FlagColumn):
9188
return False
9289

93-
return self.name == other.name and self.base == other.base and self.flag_system == other.flag_system
90+
return self.name == other.name and self.flag_system == other.flag_system
9491

9592
# Make class instances unhashable
9693
__hash__ = None
@@ -101,7 +98,7 @@ class FlagManager:
10198
10299
This class:
103100
* registers **flag systems** (bit registries) under a string name;
104-
* registers **flag columns** that reference a base data column and a specific flag system.
101+
* registers **flag columns** linked to a specific flag system.
105102
"""
106103

107104
def __init__(self) -> None:
@@ -173,22 +170,19 @@ def get_flag_system(self, flag_system_name: str) -> type[BitwiseFlag]:
173170
except KeyError:
174171
raise FlagSystemNotFoundError(f"No such flag system: '{flag_system_name}'")
175172

176-
def register_flag_column(self, name: str, base: str, flag_system_name: str) -> None:
173+
def register_flag_column(self, name: str, flag_system_name: str) -> None:
177174
"""Mark the specified existing column as a flag column.
178175
179176
Args:
180177
name: A column name to mark as a flag column.
181-
base: Name of the value/data column this flag column refers to.
182178
flag_system_name: The name of the flag system.
183179
"""
184180
flag_column = self._flag_columns.get(name)
185181
if flag_column:
186-
raise FlagSystemError(
187-
f"Flag column '{name}' already registered. Base: '{flag_column.base}'; System: '{flag_system_name}'."
188-
)
182+
raise FlagSystemError(f"Flag column '{name}' already registered. System: '{flag_system_name}'.")
189183
else:
190184
flag_system = self.get_flag_system(flag_system_name)
191-
flag_column = FlagColumn(name, base, flag_system)
185+
flag_column = FlagColumn(name, flag_system)
192186
self._flag_columns[name] = flag_column
193187

194188
def get_flag_column(self, name: str) -> FlagColumn:
@@ -215,7 +209,7 @@ def copy(self) -> Self:
215209

216210
# register flag columns in the new copy with their associated system names
217211
for name, flag_column in self._flag_columns.items():
218-
out.register_flag_column(name, flag_column.base, flag_column.flag_system.system_name())
212+
out.register_flag_column(name, flag_column.flag_system.system_name())
219213

220214
return out
221215

tests/time_stream/test_base.py

Lines changed: 46 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -132,17 +132,23 @@ def test_select_column_doesnt_mutate_original_tf(self) -> None:
132132
assert col2_tf == expected
133133
assert_frame_equal(tf.df, original_df)
134134

135-
def test_select_column_with_flags(self) -> None:
135+
def test_select_column_does_not_auto_include_flags(self) -> None:
136+
"""Flag columns are not automatically included when selecting a data column."""
136137
tf = TimeFrame(self.df, time_name="time").with_flag_system("system", {"A": 1, "B": 2, "C": 4})
137-
tf.init_flag_column("col1", "system", "flag_col")
138-
139-
expected = TimeFrame(self.df.select(["time", "col1"]), time_name="time").with_flag_system(
140-
"system", {"A": 1, "B": 2, "C": 4}
141-
)
142-
expected.init_flag_column("col1", "system", "flag_col")
138+
tf.init_flag_column("system", "flag_col")
143139

144140
result = tf.select(["col1"])
145-
assert result == expected
141+
assert "flag_col" not in result.df.columns
142+
assert "flag_col" not in result.flag_columns
143+
144+
def test_select_column_with_explicit_flag_column(self) -> None:
145+
"""Flag columns are preserved when explicitly included in select."""
146+
tf = TimeFrame(self.df, time_name="time").with_flag_system("system", {"A": 1, "B": 2, "C": 4})
147+
tf.init_flag_column("system", "flag_col")
148+
149+
result = tf.select(["col1", "flag_col"])
150+
assert "flag_col" in result.df.columns
151+
assert "flag_col" in result.flag_columns
146152

147153

148154
class TestGetItem:
@@ -370,15 +376,15 @@ def test_init_flag_column_success(self) -> None:
370376
# Shouldn't have any flag columns to start with
371377
assert tf._flag_manager._flag_columns == {}
372378

373-
tf.init_flag_column("value", "quality_control", "flag_column")
374-
assert tf._flag_manager._flag_columns == {"flag_column": FlagColumn("flag_column", "value", flag_system)}
379+
tf.init_flag_column("quality_control", "flag_column")
380+
assert tf._flag_manager._flag_columns == {"flag_column": FlagColumn("flag_column", flag_system)}
375381

376382
assert "flag_column" in tf.flag_columns
377383

378384
def test_init_column_adds_to_metadata(self) -> None:
379385
"""Test that a new flag column gets added to the column metadata dict."""
380386
tf, flag_system = self.setup_tf()
381-
tf.init_flag_column("value", "quality_control", "flag_column")
387+
tf.init_flag_column("quality_control", "flag_column")
382388
expected = {"time": {}, "value": {}, "existing_flags": {}, "flag_column": {}}
383389
assert tf.column_metadata == expected
384390

@@ -389,7 +395,7 @@ def test_init_flag_column_with_single_value(self) -> None:
389395
# Shouldn't have any flag columns to start with
390396
assert tf._flag_manager._flag_columns == {}
391397

392-
tf.init_flag_column("value", "quality_control", "flag_column", 1)
398+
tf.init_flag_column("quality_control", "flag_column", 1)
393399
expected_values = pl.Series("flag_column", [1, 1, 1], dtype=pl.Int64)
394400
assert_series_equal(tf.df["flag_column"], expected_values)
395401

@@ -400,7 +406,7 @@ def test_init_flag_column_with_list_value(self) -> None:
400406
# Shouldn't have any flag columns to start with
401407
assert tf._flag_manager._flag_columns == {}
402408

403-
tf.init_flag_column("value", "quality_control", "flag_column", [1, 2, 4])
409+
tf.init_flag_column("quality_control", "flag_column", [1, 2, 4])
404410
expected_values = pl.Series("flag_column", [1, 2, 4], dtype=pl.Int64)
405411
assert_series_equal(tf.df["flag_column"], expected_values)
406412

@@ -411,21 +417,43 @@ def test_with_no_flag_column_name(self) -> None:
411417
# Shouldn't have any flag columns to start with
412418
assert tf._flag_manager._flag_columns == {}
413419

414-
tf.init_flag_column("value", "quality_control")
415-
default_name = "value__flag__quality_control"
420+
tf.init_flag_column("quality_control")
421+
default_name = "__flag__quality_control"
416422

417-
assert tf._flag_manager._flag_columns == {default_name: FlagColumn(default_name, "value", flag_system)}
423+
assert tf._flag_manager._flag_columns == {default_name: FlagColumn(default_name, flag_system)}
418424

419425
assert default_name in tf.flag_columns
420426

421427
def test_init_no_column_name_adds_to_metadata(self) -> None:
422428
"""Test that a new flag column gets added to the column metadata dict."""
423429
tf, flag_system = self.setup_tf()
424430

425-
tf.init_flag_column("value", "quality_control")
426-
expected = {"time": {}, "value": {}, "existing_flags": {}, "value__flag__quality_control": {}}
431+
tf.init_flag_column("quality_control")
432+
expected = {"time": {}, "value": {}, "existing_flags": {}, "__flag__quality_control": {}}
427433
assert tf.column_metadata == expected
428434

435+
def test_auto_name_collision_avoidance(self) -> None:
436+
"""Test that a numeric suffix is appended when the auto-generated name already exists."""
437+
tf, flag_system = self.setup_tf()
438+
439+
tf.init_flag_column("quality_control")
440+
tf.init_flag_column("quality_control")
441+
442+
assert "__flag__quality_control" in tf.flag_columns
443+
assert "__flag__quality_control__1" in tf.flag_columns
444+
445+
def test_auto_name_collision_avoidance_multiple(self) -> None:
446+
"""Test that the numeric suffix increments correctly across multiple collisions."""
447+
tf, flag_system = self.setup_tf()
448+
449+
tf.init_flag_column("quality_control")
450+
tf.init_flag_column("quality_control")
451+
tf.init_flag_column("quality_control")
452+
453+
assert "__flag__quality_control" in tf.flag_columns
454+
assert "__flag__quality_control__1" in tf.flag_columns
455+
assert "__flag__quality_control__2" in tf.flag_columns
456+
429457

430458
class TestTimeSeriesEquality:
431459
def setUp(self) -> None:

0 commit comments

Comments
 (0)