1+ """
2+ Validation decorators for BlueMath_tk classes.
3+
4+ This module provides decorators to validate input data for various
5+ clustering, reduction, and analysis methods.
6+ """
7+
18import functools
2- from typing import Any , Dict , List
9+ from typing import Any
310
411import pandas as pd
512import xarray as xr
613
714
815def validate_data_lhs (func ):
916 """
10- Decorator to validate data in LHS class fit method.
17+ Validate data in LHS class fit method.
1118
1219 Parameters
1320 ----------
@@ -23,9 +30,9 @@ def validate_data_lhs(func):
2330 @functools .wraps (func )
2431 def wrapper (
2532 self ,
26- dimensions_names : List [str ],
27- lower_bounds : List [float ],
28- upper_bounds : List [float ],
33+ dimensions_names : list [str ],
34+ lower_bounds : list [float ],
35+ upper_bounds : list [float ],
2936 num_samples : int ,
3037 ):
3138 if not isinstance (dimensions_names , list ):
@@ -38,7 +45,8 @@ def wrapper(
3845 upper_bounds
3946 ):
4047 raise ValueError (
41- "Dimensions names, lower bounds and upper bounds must have the same length"
48+ "Dimensions names, lower bounds and upper bounds "
49+ "must have the same length"
4250 )
4351 if not all (
4452 [lower <= upper for lower , upper in zip (lower_bounds , upper_bounds )]
@@ -53,7 +61,7 @@ def wrapper(
5361
5462def validate_data_mda (func ):
5563 """
56- Decorator to validate data in MDA class fit method.
64+ Validate data in MDA class fit method.
5765
5866 Parameters
5967 ----------
@@ -70,7 +78,7 @@ def validate_data_mda(func):
7078 def wrapper (
7179 self ,
7280 data : pd .DataFrame ,
73- directional_variables : List [str ] = [],
81+ directional_variables : list [str ] = [],
7482 custom_scale_factor : dict = {},
7583 first_centroid_seed : int = None ,
7684 normalize_data : bool = False ,
@@ -90,7 +98,8 @@ def wrapper(
9098 or first_centroid_seed > data .shape [0 ]
9199 ):
92100 raise ValueError (
93- "First centroid seed must be an integer >= 0 and < num of data points"
101+ "First centroid seed must be an integer >= 0 "
102+ "and < num of data points"
94103 )
95104 if not isinstance (normalize_data , bool ):
96105 raise TypeError ("Normalize data must be a boolean" )
@@ -108,7 +117,7 @@ def wrapper(
108117
109118def validate_data_kma (func ):
110119 """
111- Decorator to validate data in KMA class fit method.
120+ Validate data in KMA class fit method.
112121
113122 Parameters
114123 ----------
@@ -125,12 +134,13 @@ def validate_data_kma(func):
125134 def wrapper (
126135 self ,
127136 data : pd .DataFrame ,
128- directional_variables : List [str ] = [],
137+ directional_variables : list [str ] = [],
129138 custom_scale_factor : dict = {},
130139 min_number_of_points : int = None ,
131140 max_number_of_iterations : int = 10 ,
132141 normalize_data : bool = False ,
133- regression_guided : Dict [str , List ] = {},
142+ regression_guided : dict [str , list ] = {},
143+ init_mda_centroids : pd .DataFrame = None ,
134144 ):
135145 if data is None :
136146 raise ValueError ("data cannot be None" )
@@ -157,7 +167,8 @@ def wrapper(
157167 for var in regression_guided .get ("vars" , [])
158168 ):
159169 raise TypeError (
160- "regression_guided vars must be a list of strings and must exist in data"
170+ "regression_guided vars must be a list of strings "
171+ "and must exist in data"
161172 )
162173 if not all (
163174 isinstance (alpha , float ) and alpha >= 0 and alpha <= 1
@@ -175,14 +186,15 @@ def wrapper(
175186 max_number_of_iterations ,
176187 normalize_data ,
177188 regression_guided ,
189+ init_mda_centroids ,
178190 )
179191
180192 return wrapper
181193
182194
183195def validate_data_som (func ):
184196 """
185- Decorator to validate data in SOM class fit method.
197+ Validate data in SOM class fit method.
186198
187199 Parameters
188200 ----------
@@ -199,7 +211,7 @@ def validate_data_som(func):
199211 def wrapper (
200212 self ,
201213 data : pd .DataFrame ,
202- directional_variables : List [str ] = [],
214+ directional_variables : list [str ] = [],
203215 custom_scale_factor : dict = {},
204216 num_iteration : int = 1000 ,
205217 normalize_data : bool = False ,
@@ -230,7 +242,7 @@ def wrapper(
230242
231243def validate_data_pca (func ):
232244 """
233- Decorator to validate data in PCA class fit method.
245+ Validate data in PCA class fit method.
234246
235247 Parameters
236248 ----------
@@ -247,8 +259,8 @@ def validate_data_pca(func):
247259 def wrapper (
248260 self ,
249261 data : xr .Dataset ,
250- vars_to_stack : List [str ],
251- coords_to_stack : List [str ],
262+ vars_to_stack : list [str ],
263+ coords_to_stack : list [str ],
252264 pca_dim_for_rows : str ,
253265 windows_in_pca_dim_for_rows : dict = {},
254266 value_to_replace_nans : dict = {},
@@ -263,18 +275,22 @@ def wrapper(
263275 for var in vars_to_stack :
264276 if var not in data .data_vars :
265277 raise ValueError (f"Variable { var } not found in data" )
266- # Check that all variables in vars_to_stack have the same coordinates and dimensions
278+ # Check that all variables in vars_to_stack have the same
279+ # coordinates and dimensions
280+ first_var = vars_to_stack [0 ]
267281 first_var = vars_to_stack [0 ]
268282 first_var_dims = list (data [first_var ].dims )
269283 first_var_coords = list (data [first_var ].coords )
270284 for var in vars_to_stack :
271285 if list (data [var ].dims ) != first_var_dims :
272286 raise ValueError (
273- f"All variables must have the same dimensions. Variable { var } does not match."
287+ f"All variables must have the same dimensions. "
288+ f"Variable { var } does not match."
274289 )
275290 if list (data [var ].coords ) != first_var_coords :
276291 raise ValueError (
277- f"All variables must have the same coordinates. Variable { var } does not match."
292+ f"All variables must have the same coordinates. "
293+ f"Variable { var } does not match."
278294 )
279295 # Check that all coords_to_stack are in the data
280296 if not isinstance (coords_to_stack , list ) or len (coords_to_stack ) == 0 :
@@ -285,7 +301,8 @@ def wrapper(
285301 # Check that pca_dim_for_rows is in the data, and window > 0 if provided
286302 if not isinstance (pca_dim_for_rows , str ) or pca_dim_for_rows not in data .dims :
287303 raise ValueError (
288- "PCA dimension for rows must be a string and found in the data dimensions"
304+ "PCA dimension for rows must be a string "
305+ "and found in the data dimensions"
289306 )
290307 for variable , windows in windows_in_pca_dim_for_rows .items ():
291308 if not isinstance (windows , list ):
@@ -314,7 +331,7 @@ def wrapper(
314331
315332def validate_data_rbf (func ):
316333 """
317- Decorator to validate data in RBF class fit method.
334+ Validate data in RBF class fit method.
318335
319336 Parameters
320337 ----------
@@ -332,8 +349,8 @@ def wrapper(
332349 self ,
333350 subset_data : pd .DataFrame ,
334351 target_data : pd .DataFrame ,
335- subset_directional_variables : List [str ] = [],
336- target_directional_variables : List [str ] = [],
352+ subset_directional_variables : list [str ] = [],
353+ target_directional_variables : list [str ] = [],
337354 subset_custom_scale_factor : dict = {},
338355 normalize_target_data : bool = True ,
339356 target_custom_scale_factor : dict = {},
@@ -353,14 +370,16 @@ def wrapper(
353370 for directional_variable in subset_directional_variables :
354371 if directional_variable not in subset_data .columns :
355372 raise ValueError (
356- f"Directional variable { directional_variable } not found in subset data"
373+ f"Directional variable { directional_variable } "
374+ f"not found in subset data"
357375 )
358376 if not isinstance (target_directional_variables , list ):
359377 raise TypeError ("Target directional variables must be a list" )
360378 for directional_variable in target_directional_variables :
361379 if directional_variable not in target_data .columns :
362380 raise ValueError (
363- f"Directional variable { directional_variable } not found in target data"
381+ f"Directional variable { directional_variable } "
382+ f"not found in target data"
364383 )
365384 if not isinstance (subset_custom_scale_factor , dict ):
366385 raise TypeError ("Subset custom scale factor must be a dict" )
@@ -391,7 +410,7 @@ def wrapper(
391410
392411def validate_data_xwt (func ):
393412 """
394- Decorator to validate data in XWT class fit method.
413+ Validate data in XWT class fit method.
395414
396415 Parameters
397416 ----------
@@ -408,7 +427,7 @@ def validate_data_xwt(func):
408427 def wrapper (
409428 self ,
410429 data : xr .Dataset ,
411- fit_params : Dict [str , Dict [str , Any ]] = {},
430+ fit_params : dict [str , dict [str , Any ]] = {},
412431 variable_to_sort_bmus : str = None ,
413432 ):
414433 if not isinstance (data , xr .Dataset ):
@@ -427,7 +446,8 @@ def wrapper(
427446 or variable_to_sort_bmus not in data .data_vars
428447 ):
429448 raise TypeError (
430- "variable_to_sort_bmus must be a string and must exist in data variables"
449+ "variable_to_sort_bmus must be a string "
450+ "and must exist in data variables"
431451 )
432452 return func (
433453 self ,
@@ -441,7 +461,7 @@ def wrapper(
441461
442462def validate_data_calval (func ):
443463 """
444- Decorator to validate data in CalVal class fit method.
464+ Validate data in CalVal class fit method.
445465
446466 Parameters
447467 ----------
0 commit comments