Skip to content

Commit 2860023

Browse files
authored
Merge pull request #142 from GeoOcean/feature/generalized-kma
[JTH] add kma generalized for more algorithms
2 parents 2f8a682 + 15f4ec4 commit 2860023

6 files changed

Lines changed: 1492 additions & 290 deletions

File tree

.github/workflows/python-tests.yml

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,27 @@ on:
99
- synchronize
1010

1111
jobs:
12+
# Separate linting job for faster feedback
13+
lint:
14+
runs-on: ubuntu-latest
15+
steps:
16+
- name: Checkout code
17+
uses: actions/checkout@v4
18+
19+
- name: Set up Python
20+
uses: actions/setup-python@v5
21+
with:
22+
python-version: "3.11"
23+
24+
- name: Install dependencies
25+
run: |
26+
python -m pip install --upgrade pip
27+
pip install .[tests]
28+
29+
- name: Lint with ruff
30+
run: |
31+
ruff check bluemath_tk/ || true # TODO: Remove || true once docstrings are fixed
32+
1233
python-tests:
1334
runs-on: ${{ matrix.os }}
1435

@@ -32,10 +53,6 @@ jobs:
3253
python -m pip install --upgrade pip
3354
pip install .[all,tests]
3455
35-
- name: Lint
36-
run: |
37-
ruff check bluemath_tk/datamining/ || true # optional: don't fail lint for now
38-
3956
- name: Run tests
4057
run: |
4158
pytest -s -v tests

bluemath_tk/core/decorators.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1+
"""
2+
Validation decorators for BlueMath_tk classes.
3+
4+
This module provides decorators to validate input data for various
5+
clustering, reduction, and analysis methods.
6+
"""
7+
18
import functools
2-
from typing import Any, Dict, List
9+
from typing import Any
310

411
import pandas as pd
512
import xarray as xr
613

714

815
def validate_data_lhs(func):
916
"""
10-
Decorator to validate data in LHS class fit method.
17+
Validate data in LHS class fit method.
1118
1219
Parameters
1320
----------
@@ -23,9 +30,9 @@ def validate_data_lhs(func):
2330
@functools.wraps(func)
2431
def wrapper(
2532
self,
26-
dimensions_names: List[str],
27-
lower_bounds: List[float],
28-
upper_bounds: List[float],
33+
dimensions_names: list[str],
34+
lower_bounds: list[float],
35+
upper_bounds: list[float],
2936
num_samples: int,
3037
):
3138
if not isinstance(dimensions_names, list):
@@ -38,7 +45,8 @@ def wrapper(
3845
upper_bounds
3946
):
4047
raise ValueError(
41-
"Dimensions names, lower bounds and upper bounds must have the same length"
48+
"Dimensions names, lower bounds and upper bounds "
49+
"must have the same length"
4250
)
4351
if not all(
4452
[lower <= upper for lower, upper in zip(lower_bounds, upper_bounds)]
@@ -53,7 +61,7 @@ def wrapper(
5361

5462
def validate_data_mda(func):
5563
"""
56-
Decorator to validate data in MDA class fit method.
64+
Validate data in MDA class fit method.
5765
5866
Parameters
5967
----------
@@ -70,7 +78,7 @@ def validate_data_mda(func):
7078
def wrapper(
7179
self,
7280
data: pd.DataFrame,
73-
directional_variables: List[str] = [],
81+
directional_variables: list[str] = [],
7482
custom_scale_factor: dict = {},
7583
first_centroid_seed: int = None,
7684
normalize_data: bool = False,
@@ -90,7 +98,8 @@ def wrapper(
9098
or first_centroid_seed > data.shape[0]
9199
):
92100
raise ValueError(
93-
"First centroid seed must be an integer >= 0 and < num of data points"
101+
"First centroid seed must be an integer >= 0 "
102+
"and < num of data points"
94103
)
95104
if not isinstance(normalize_data, bool):
96105
raise TypeError("Normalize data must be a boolean")
@@ -108,7 +117,7 @@ def wrapper(
108117

109118
def validate_data_kma(func):
110119
"""
111-
Decorator to validate data in KMA class fit method.
120+
Validate data in KMA class fit method.
112121
113122
Parameters
114123
----------
@@ -125,12 +134,13 @@ def validate_data_kma(func):
125134
def wrapper(
126135
self,
127136
data: pd.DataFrame,
128-
directional_variables: List[str] = [],
137+
directional_variables: list[str] = [],
129138
custom_scale_factor: dict = {},
130139
min_number_of_points: int = None,
131140
max_number_of_iterations: int = 10,
132141
normalize_data: bool = False,
133-
regression_guided: Dict[str, List] = {},
142+
regression_guided: dict[str, list] = {},
143+
init_mda_centroids: pd.DataFrame = None,
134144
):
135145
if data is None:
136146
raise ValueError("data cannot be None")
@@ -157,7 +167,8 @@ def wrapper(
157167
for var in regression_guided.get("vars", [])
158168
):
159169
raise TypeError(
160-
"regression_guided vars must be a list of strings and must exist in data"
170+
"regression_guided vars must be a list of strings "
171+
"and must exist in data"
161172
)
162173
if not all(
163174
isinstance(alpha, float) and alpha >= 0 and alpha <= 1
@@ -175,14 +186,15 @@ def wrapper(
175186
max_number_of_iterations,
176187
normalize_data,
177188
regression_guided,
189+
init_mda_centroids,
178190
)
179191

180192
return wrapper
181193

182194

183195
def validate_data_som(func):
184196
"""
185-
Decorator to validate data in SOM class fit method.
197+
Validate data in SOM class fit method.
186198
187199
Parameters
188200
----------
@@ -199,7 +211,7 @@ def validate_data_som(func):
199211
def wrapper(
200212
self,
201213
data: pd.DataFrame,
202-
directional_variables: List[str] = [],
214+
directional_variables: list[str] = [],
203215
custom_scale_factor: dict = {},
204216
num_iteration: int = 1000,
205217
normalize_data: bool = False,
@@ -230,7 +242,7 @@ def wrapper(
230242

231243
def validate_data_pca(func):
232244
"""
233-
Decorator to validate data in PCA class fit method.
245+
Validate data in PCA class fit method.
234246
235247
Parameters
236248
----------
@@ -247,8 +259,8 @@ def validate_data_pca(func):
247259
def wrapper(
248260
self,
249261
data: xr.Dataset,
250-
vars_to_stack: List[str],
251-
coords_to_stack: List[str],
262+
vars_to_stack: list[str],
263+
coords_to_stack: list[str],
252264
pca_dim_for_rows: str,
253265
windows_in_pca_dim_for_rows: dict = {},
254266
value_to_replace_nans: dict = {},
@@ -263,18 +275,22 @@ def wrapper(
263275
for var in vars_to_stack:
264276
if var not in data.data_vars:
265277
raise ValueError(f"Variable {var} not found in data")
266-
# Check that all variables in vars_to_stack have the same coordinates and dimensions
278+
# Check that all variables in vars_to_stack have the same
279+
# coordinates and dimensions
280+
first_var = vars_to_stack[0]
267281
first_var = vars_to_stack[0]
268282
first_var_dims = list(data[first_var].dims)
269283
first_var_coords = list(data[first_var].coords)
270284
for var in vars_to_stack:
271285
if list(data[var].dims) != first_var_dims:
272286
raise ValueError(
273-
f"All variables must have the same dimensions. Variable {var} does not match."
287+
f"All variables must have the same dimensions. "
288+
f"Variable {var} does not match."
274289
)
275290
if list(data[var].coords) != first_var_coords:
276291
raise ValueError(
277-
f"All variables must have the same coordinates. Variable {var} does not match."
292+
f"All variables must have the same coordinates. "
293+
f"Variable {var} does not match."
278294
)
279295
# Check that all coords_to_stack are in the data
280296
if not isinstance(coords_to_stack, list) or len(coords_to_stack) == 0:
@@ -285,7 +301,8 @@ def wrapper(
285301
# Check that pca_dim_for_rows is in the data, and window > 0 if provided
286302
if not isinstance(pca_dim_for_rows, str) or pca_dim_for_rows not in data.dims:
287303
raise ValueError(
288-
"PCA dimension for rows must be a string and found in the data dimensions"
304+
"PCA dimension for rows must be a string "
305+
"and found in the data dimensions"
289306
)
290307
for variable, windows in windows_in_pca_dim_for_rows.items():
291308
if not isinstance(windows, list):
@@ -314,7 +331,7 @@ def wrapper(
314331

315332
def validate_data_rbf(func):
316333
"""
317-
Decorator to validate data in RBF class fit method.
334+
Validate data in RBF class fit method.
318335
319336
Parameters
320337
----------
@@ -332,8 +349,8 @@ def wrapper(
332349
self,
333350
subset_data: pd.DataFrame,
334351
target_data: pd.DataFrame,
335-
subset_directional_variables: List[str] = [],
336-
target_directional_variables: List[str] = [],
352+
subset_directional_variables: list[str] = [],
353+
target_directional_variables: list[str] = [],
337354
subset_custom_scale_factor: dict = {},
338355
normalize_target_data: bool = True,
339356
target_custom_scale_factor: dict = {},
@@ -353,14 +370,16 @@ def wrapper(
353370
for directional_variable in subset_directional_variables:
354371
if directional_variable not in subset_data.columns:
355372
raise ValueError(
356-
f"Directional variable {directional_variable} not found in subset data"
373+
f"Directional variable {directional_variable} "
374+
f"not found in subset data"
357375
)
358376
if not isinstance(target_directional_variables, list):
359377
raise TypeError("Target directional variables must be a list")
360378
for directional_variable in target_directional_variables:
361379
if directional_variable not in target_data.columns:
362380
raise ValueError(
363-
f"Directional variable {directional_variable} not found in target data"
381+
f"Directional variable {directional_variable} "
382+
f"not found in target data"
364383
)
365384
if not isinstance(subset_custom_scale_factor, dict):
366385
raise TypeError("Subset custom scale factor must be a dict")
@@ -391,7 +410,7 @@ def wrapper(
391410

392411
def validate_data_xwt(func):
393412
"""
394-
Decorator to validate data in XWT class fit method.
413+
Validate data in XWT class fit method.
395414
396415
Parameters
397416
----------
@@ -408,7 +427,7 @@ def validate_data_xwt(func):
408427
def wrapper(
409428
self,
410429
data: xr.Dataset,
411-
fit_params: Dict[str, Dict[str, Any]] = {},
430+
fit_params: dict[str, dict[str, Any]] = {},
412431
variable_to_sort_bmus: str = None,
413432
):
414433
if not isinstance(data, xr.Dataset):
@@ -427,7 +446,8 @@ def wrapper(
427446
or variable_to_sort_bmus not in data.data_vars
428447
):
429448
raise TypeError(
430-
"variable_to_sort_bmus must be a string and must exist in data variables"
449+
"variable_to_sort_bmus must be a string "
450+
"and must exist in data variables"
431451
)
432452
return func(
433453
self,
@@ -441,7 +461,7 @@ def wrapper(
441461

442462
def validate_data_calval(func):
443463
"""
444-
Decorator to validate data in CalVal class fit method.
464+
Validate data in CalVal class fit method.
445465
446466
Parameters
447467
----------

0 commit comments

Comments
 (0)